Legacy API changes (#441)

* initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* another initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* another initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* one more initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored buffer() and shapeInfo() methods usage with NDArray class.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt Graph class methods to use const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt choose op to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt where op shape method to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt lstsq op to use constant empty shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt matrix_diag_part op shape routine to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt determinant ops to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt mean_pairwssqerr_loss ops to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape methods for loss ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt log_loss op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape methods for ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt dilation2d ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted deconv2d ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted dynamicRNN op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for lstm layer ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few updates

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* first cuda tweak

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Adopt constant shapes for sconv2d ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt constant shapes for gru ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt constant shapes with shape methods for segment ops and so on.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted constant shapes with unsorted_segment_* ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted constant shapes with gamma op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods of reduce_stddev ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for reduce_* ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape method for squeeze op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt strided_slice shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored concat op shape method to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape method for mirror_pad op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted split op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted tile ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added const cast for mkldnn routines handles.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cosmetic changes to proper usage of constant pointers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored depthToSpace helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored histogram helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored im2col helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored gather and gatherND helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage on percentile helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed gather shape with helpers and range buffer usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with space to depth helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage and constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with LUP decomposition>

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored onehot_ helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pad and prefix to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactoed softmax helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed space to batch helpers to use buffers properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed stack and split helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with sparse to dense helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with mindistance_ helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with tile helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage with legacy pairwise bool ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored a couple of methods to adopt constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed broadcasting with constant shape."

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const usage with inplace reverse and constant shapes with legacy reduction.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored legacy ops with const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored sort to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected sort for constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage with special methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored Context to conform with constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* CUDA broadcasting headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* pairwise/indexreduce/random headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored native ops to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* legacy reduce3/scalar headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Corrected pullRow signature and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected routines to proper use of constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tests to use constant shapes properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored legacy ops tests to use constant shapes properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage with NDArray tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed native ops tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed special concat routine.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with a test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored TAD.h and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored calcStrides* routines to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed miscelaneous errors with constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* NativeOps const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Corrected definitions for declared functions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* NativeOps const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed const shapes with shape routines.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed shape method for broadcastable case.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* xw_plus_b BP shape fn restored

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed signatures with broadcasting.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Repaired backprops shape methods for a set of operations.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored broadcast bool for cuda.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored methods for 3 args with const qualifier.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed a couple of kernel signatures for broadcasting.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels signatures for const buffers and shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise methods to persistent buffers and shapes usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt const to buffers and shapes with kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt const to buffers and shapes with scalar kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored indexreduce kernels signatures to use const buffers and shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise kernels to adopt cons shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise bool kernels to adopt cons shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored random special ops to conform with const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored native ops to conform with const shapes and buffers under cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cosmetical changes only.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes and buffers error.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected start pos routine.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored methods to conform with const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored helpers to use proper methods instead.

Signed-off-by: shugeo <sgazeos@gmail.com>

* bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed execScalar declaration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed execScalar declaration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected const shape cases with sort and so on.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes for sort.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored kernel declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected kernel declarations to adopt const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed segment helpers kernels declarations and so on to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shape usage with segment and solve helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernel declaration with adjustWeight helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed cuda implementations for constant shape helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted const shape usage with kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted top_k kernels to use const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected kernels declarations to adopt const shapes with helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored NDArray definitions to adopt const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes with image suppression helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Slight improvement with buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage with tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shape usage with definitions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* minor updates on cpu side

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored const shape usage with ConstantDescritor and native ops with cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tear and tile kernels to adopt with const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* softmax_loop fix

Signed-off-by: raver119 <raver119@gmail.com>

* update missing signature

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* softmax again

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* few more missing consts

Signed-off-by: raver119 <raver119@gmail.com>

* new methods updated

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

Co-authored-by: shugeo <sgazeos@gmail.com>
master
raver119 2020-05-09 08:06:14 +03:00 committed by GitHub
parent 0613485654
commit 320924278d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
470 changed files with 7521 additions and 7468 deletions

View File

@ -35,7 +35,7 @@ namespace sd {
std::vector<double> _floatValues;
public:
ConstantDescriptor(double* values, int length);
ConstantDescriptor(Nd4jLong* values, int length);
ConstantDescriptor(Nd4jLong const* values, int length);
ConstantDescriptor(std::initializer_list<double> values);
explicit ConstantDescriptor(std::vector<Nd4jLong> &values);

View File

@ -125,7 +125,7 @@ namespace sd {
void templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const;
template <typename T, typename R>
FORCEINLINE R templatedGet(void *buffer, const Nd4jLong index) const;
FORCEINLINE R templatedGet(void const* buffer, const Nd4jLong index) const;
/*
template <typename T, typename R>
R templatedGetIndex(void *buffer, Nd4jLong *indices) const;
@ -193,7 +193,7 @@ namespace sd {
#ifndef __JAVACPP_HACK__
NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const Nd4jLong offset = 0);
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
NDArray(std::shared_ptr<DataBuffer> buffer, char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
/**
* This contructors create scalar array containing string utf8
@ -250,13 +250,14 @@ namespace sd {
/**
* do not allocate memory, memory for array is passed from outside
*/
NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false);
NDArray(void *buffer, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false);
/**
* do not allocate memory, memory for array is passed from outside
* we suppose the content of both (device and host) buffers is identical
*/
NDArray(void *buffer, void *bufferD, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false, const bool isBuffDAlloc = false);
NDArray(void *buffer, void *bufferD, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false, bool isBuffDAlloc = false);
/**
* copy constructor
@ -277,28 +278,28 @@ namespace sd {
/**
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
*/
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true);
NDArray(const Nd4jLong* shapeInfo, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true);
/**
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
* set dtype as array type
*/
NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true);
NDArray(const Nd4jLong* shapeInfo, sd::DataType dtype, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true);
/**
* this constructor creates new array using shape information contained in vector argument
*/
NDArray(const char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
NDArray(char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
/**
* This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype
*/
NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
NDArray(char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
/**
* this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
*/
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
NDArray(void *buffer, char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
/**
* This method returns new array with the same shape & data type
@ -317,12 +318,12 @@ namespace sd {
* this constructor creates new NDArray with shape matching "other" array,
* doesn't copy "other" elements into new array !!!
*/
explicit NDArray(const NDArray* other, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext());
explicit NDArray(const NDArray* other, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext());
/**
* this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar
*/
NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isScalar = true);
NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isScalar = true);
/**
* This method blocks until asynchronous operation finishes
@ -364,9 +365,11 @@ namespace sd {
* @param offset
* @return
*/
void *bufferWithOffset(Nd4jLong offset) const;
void const* bufferWithOffset(Nd4jLong offset) const;
void* bufferWithOffset(Nd4jLong offset);
void* specialBufferWithOffset(Nd4jLong offset) const;
void const* specialBufferWithOffset(Nd4jLong offset) const;
void* specialBufferWithOffset(Nd4jLong offset);
/**
* copy assignment operator
* in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType
@ -450,38 +453,39 @@ namespace sd {
/**
* returns host buffer
*/
FORCEINLINE void* getBuffer() const;
FORCEINLINE void* buffer();
FORCEINLINE const void* buffer() const;
/**
* returns buffer offset (offset is the same for host and device buffers)
*/
FORCEINLINE Nd4jLong getBufferOffset() const;
FORCEINLINE Nd4jLong bufferOffset();
FORCEINLINE Nd4jLong bufferOffset() const;
/**
* if _bufferD==nullptr return _buffer, else return _bufferD
*/
void* specialBuffer();
void* getSpecialBuffer() const;
const void* specialBuffer() const;
/**
* returns device buffer if compilation is for cuda case, otherwise returns host buffer
*/
void* getPlatformBuffer() const;
void* platformBuffer();
const void* platformBuffer() const;
template <typename T>
T* bufferAsT() const;
T* bufferAsT();
template <typename T>
const T* bufferAsT() const;
/**
* returns _shapeInfo
*/
FORCEINLINE Nd4jLong* shapeInfo();
FORCEINLINE Nd4jLong* getShapeInfo() const;
FORCEINLINE const Nd4jLong* shapeInfo() const;
/**
@ -493,12 +497,9 @@ namespace sd {
/**
* if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD
*/
FORCEINLINE Nd4jLong* specialShapeInfo();
FORCEINLINE Nd4jLong* getSpecialShapeInfo() const;
FORCEINLINE const Nd4jLong* specialShapeInfo() const;
Nd4jLong* platformShapeInfo();
Nd4jLong* getPlatformShapeInfo() const;
const Nd4jLong* platformShapeInfo() const;
/**
* permutes (in-place) the dimensions in array according to "dimensions" array
@ -1509,8 +1510,8 @@ bool NDArray::isAttached() {
}
template <typename T, typename R>
FORCEINLINE R NDArray::templatedGet(void *buffer, Nd4jLong index) const {
auto b = reinterpret_cast<T*>(buffer);
FORCEINLINE R NDArray::templatedGet(void const* buffer, Nd4jLong index) const {
auto b = reinterpret_cast<T const*>(buffer);
auto v = static_cast<R>(b[index]);
return v;
}
@ -1625,9 +1626,9 @@ bool NDArray::nonNull() const {
return true;
if(!Environment::getInstance()->isCPU())
return getDataBuffer()->special() != nullptr && getSpecialShapeInfo() != nullptr;
return getDataBuffer()->special() != nullptr && specialShapeInfo() != nullptr;
return getDataBuffer()->primary() != nullptr && getShapeInfo() != nullptr;
return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr;
}
//////////////////////////////////////////////////////////////////////////
@ -1744,7 +1745,7 @@ bool NDArray::isEmpty() const {
if (this->_shapeInfo == nullptr)
return false;
return ArrayOptions::arrayType(this->getShapeInfo()) == ArrayType::EMPTY;
return ArrayOptions::arrayType(this->shapeInfo()) == ArrayType::EMPTY;
}
//////////////////////////////////////////////////////////////////////////
@ -1804,7 +1805,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j) {
syncToHost();
Nd4jLong coords[2] = {i, j};
auto offset = shape::getOffset(getShapeInfo(), coords);
auto offset = shape::getOffset(shapeInfo(), coords);
tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
@ -1821,7 +1822,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
syncToHost();
Nd4jLong coords[3] = {i, j, k};
auto offset = shape::getOffset(getShapeInfo(), coords);
auto offset = shape::getOffset(shapeInfo(), coords);
tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
@ -1838,7 +1839,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLo
syncToHost();
Nd4jLong coords[4] = {i, j, k, w};
auto offset = shape::getOffset(getShapeInfo(), coords);
auto offset = shape::getOffset(shapeInfo(), coords);
tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
@ -1856,7 +1857,7 @@ T NDArray::t(const Nd4jLong i) const {
syncToHost();
tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(getOffset(i))));
return *(reinterpret_cast<const T*>(bufferWithOffset(getOffset(i))));
}
////////////////////////////////////////////////////////////////////////
@ -1872,9 +1873,9 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
syncToHost();
Nd4jLong coords[2] = {i, j};
auto offset = shape::getOffset(getShapeInfo(), coords);
auto offset = shape::getOffset(shapeInfo(), coords);
tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
}
template <typename T>
@ -1889,9 +1890,9 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
syncToHost();
Nd4jLong coords[3] = {i, j, k};
auto offset = shape::getOffset(getShapeInfo(), coords);
auto offset = shape::getOffset(shapeInfo(), coords);
tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
}
template <typename T>
@ -1906,9 +1907,9 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
syncToHost();
Nd4jLong coords[4] = {i, j, k, w};
auto offset = shape::getOffset(getShapeInfo(), coords);
auto offset = shape::getOffset(shapeInfo(), coords);
tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
}
#ifndef __JAVACPP_HACK__
@ -1924,8 +1925,7 @@ std::shared_ptr<DataBuffer> NDArray::dataBuffer() {
#endif
////////////////////////////////////////////////////////////////////////
void* NDArray::getBuffer() const {
const void* NDArray::buffer() const {
return _buffer->primary() != nullptr ? static_cast<int8_t*>(_buffer->primary()) + (_offset * sizeOfT()) : nullptr;
}
@ -1934,18 +1934,13 @@ void* NDArray::buffer() {
return _buffer->primary() != nullptr ? static_cast<int8_t*>(_buffer->primary()) + (_offset * sizeOfT()) : nullptr;
}
////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::getShapeInfo() const {
return _shapeInfo;
}
//////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::shapeInfo() {
const Nd4jLong* NDArray::shapeInfo() const {
return _shapeInfo;
}
////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::specialShapeInfo() {
const Nd4jLong* NDArray::specialShapeInfo() const {
if (_shapeInfoD == nullptr)
return _shapeInfo;
// FIXME: this should be fixed once CUDA backend added
@ -1953,23 +1948,10 @@ Nd4jLong* NDArray::specialShapeInfo() {
}
////////////////////////////////////////////////////////////////////////
Nd4jLong NDArray::getBufferOffset() const {
Nd4jLong NDArray::bufferOffset() const {
return _offset;
}
////////////////////////////////////////////////////////////////////////
Nd4jLong NDArray::bufferOffset() {
return _offset;
}
////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::getSpecialShapeInfo() const{
if (_shapeInfoD == nullptr)
return _shapeInfo;
// FIXME: this should be fixed once CUDA backend added
return _shapeInfoD;
}
#if defined(__CUDACC__) //&& defined(BUILD_TESTS)
// for CUDA we need stil stuff inline

File diff suppressed because it is too large Load Diff

View File

@ -23,26 +23,26 @@
#include <cuda.h>
#include <cuda_runtime.h>
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo);
}
static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T>
class LambdaHelper {
public:
template <typename Lambda>
FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream);
if (err != 0)
@ -50,7 +50,7 @@ public:
}
template <typename Lambda>
FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaIndexedKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream);
if (err != 0)
@ -58,7 +58,7 @@ public:
}
template <typename Lambda>
FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream);
if (err != 0)
@ -66,7 +66,7 @@ public:
}
template <typename Lambda>
FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaIndexedPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream);
if (err != 0)
@ -74,7 +74,7 @@ public:
}
template <typename Lambda>
FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream, void* vw, Nd4jLong *wShapeInfo, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream,const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaTriplewiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream);
if (err != 0)
@ -84,8 +84,8 @@ public:
////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda>
static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx);
static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo);
@ -113,8 +113,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx);
static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo);
@ -142,9 +142,9 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx);
auto y = reinterpret_cast<T*>(vy);
static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<const T*>(vx);
auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo);
@ -175,9 +175,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda>
static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx);
auto y = reinterpret_cast<T*>(vy);
static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<const T*>(vx);
auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo);
@ -208,10 +208,10 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda>
static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) {
auto w = reinterpret_cast<T*>(vw);
auto x = reinterpret_cast<T*>(vx);
auto y = reinterpret_cast<T*>(vy);
static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto w = reinterpret_cast<const T*>(vw);
auto x = reinterpret_cast<const T*>(vx);
auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz);
auto wEws = shape::elementWiseStride(wShapeInfo);
@ -271,7 +271,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& ta
//throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());
prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({&target}, {this, &other});
}
@ -298,7 +298,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& t
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same");
prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({&target}, {this, &other});
}

View File

@ -28,26 +28,24 @@
namespace sd {
class ND4J_EXPORT ShapeList {
protected:
std::vector<Nd4jLong*> _shapes;
std::vector<const Nd4jLong*> _shapes;
bool _destroyed = false;
bool _autoremovable = false;
bool _workspace = false;
public:
ShapeList(Nd4jLong* shape = nullptr);
ShapeList(std::initializer_list<Nd4jLong*> shapes);
ShapeList(std::initializer_list<Nd4jLong*> shapes, bool isWorkspace);
ShapeList(std::vector<Nd4jLong*>& shapes);
ShapeList(const Nd4jLong* shape = nullptr);
ShapeList(const std::vector<const Nd4jLong*> &shapes, bool isWorkspace);
ShapeList(const std::vector<const Nd4jLong*>& shapes);
//ShapeList(bool autoRemovable);
~ShapeList();
std::vector<Nd4jLong*>* asVector();
std::vector<const Nd4jLong*>* asVector();
void destroy();
int size();
Nd4jLong* at(int idx);
void push_back(Nd4jLong *shape);
void push_back(std::vector<Nd4jLong>& shape);
int size() const;
const Nd4jLong* at(int idx);
void push_back(const Nd4jLong *shape);
/**
* PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak

View File

@ -28,18 +28,18 @@ namespace sd {
private:
ConstantDataBuffer _tadShape;
ConstantDataBuffer _tadOffsets;
Nd4jLong _numTads;
int _shapeInfoLength;
Nd4jLong _numTads = 0 ;
int _shapeInfoLength = 0;
public:
explicit TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads);
TadPack() = default;
~TadPack() = default;
Nd4jLong* primaryShapeInfo() const;
Nd4jLong* primaryOffsets() const;
const Nd4jLong* primaryShapeInfo() const;
const Nd4jLong* primaryOffsets() const;
Nd4jLong* specialShapeInfo() const;
Nd4jLong* specialOffsets() const;
const Nd4jLong* specialShapeInfo() const;
const Nd4jLong* specialOffsets() const;
Nd4jLong numberOfTads() const;
int shapeInfoLength() const;
@ -48,8 +48,8 @@ namespace sd {
* These methods return either primary or special pointers depending on platform binaries were compiled for
* @return
*/
Nd4jLong *platformShapeInfo() const;
Nd4jLong *platformOffsets() const;
const Nd4jLong *platformShapeInfo() const;
const Nd4jLong *platformOffsets() const;
};
}

View File

@ -52,10 +52,9 @@ namespace sd {
////////////////////////////////////////////////////////////////////////
void* NDArray::platformBuffer() { return buffer(); }
void* NDArray::getPlatformBuffer() const { return getBuffer(); }
void const* NDArray::platformBuffer() const { return buffer(); }
Nd4jLong* NDArray::getPlatformShapeInfo() const { return getShapeInfo(); }
Nd4jLong* NDArray::platformShapeInfo() { return shapeInfo(); }
Nd4jLong const* NDArray::platformShapeInfo() const { return shapeInfo(); }
void NDArray::syncToDevice() const { }
void NDArray::syncToHost() const { }
@ -85,15 +84,15 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
upper = target.sizeAt(-1);
const T value = static_cast<T>(val);
const auto x = reinterpret_cast<const T*>(getBuffer());
auto z = reinterpret_cast<T*>(target.getBuffer());
const auto x = reinterpret_cast<const T*>(buffer());
auto z = reinterpret_cast<T*>(target.buffer());
const int xRank = rankOf();
const int zRank = target.rankOf();
const auto zLen = target.lengthOf();
const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo());
const bool areSameOffsets = shape::haveSameShapeAndStrides(shapeInfo(), target.shapeInfo());
auto func = PRAGMA_THREADS_FOR {
@ -101,8 +100,8 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, target.getShapeInfo(), coords);
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
shape::index2coordsCPU(start, i, target.shapeInfo(), coords);
const auto zOffset = shape::getOffset(target.shapeInfo(), coords);
// if( (row + upper < col) || (row + lower > col) )
if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
@ -113,7 +112,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
coords[0] = coords[1];
}
const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords);
const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(shapeInfo(), coords);
z[zOffset] = x[xOffset];
if (xRank != zRank) // restore first coordinate
@ -140,7 +139,7 @@ void NDArray::setIdentity() {
for(int j = 0; j < rank; ++j)
indices[j] = 1;
Nd4jLong offset = shape::getOffset(getShapeInfo(), indices);
Nd4jLong offset = shape::getOffset(shapeInfo(), indices);
for(int i = 0; i < rank; ++i)
if(minDim > shape[i])
@ -214,23 +213,28 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
}
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBufferWithOffset(Nd4jLong offset) {
return nullptr;
}
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
const void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return nullptr;
}
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
return buffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
void const* NDArray::specialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
return buffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
@ -253,7 +257,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
NDArray result(*this);
if(diff < 0) { // reshape to higher dimension
std::vector<Nd4jLong> shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
result.reshapei(ordering(), shapeNew);
}
return result; // nothing to do, if diff >= 0 -> identity tile
@ -274,8 +278,8 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), i, this->buffer(), yOffset), LIBND4J_TYPES);
}
};
@ -286,8 +290,8 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
auto xOffset = result.getOffset(i);
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), xOffset, this->buffer(), yOffset), LIBND4J_TYPES);
}
};
@ -307,7 +311,7 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
// evaluate true tile shapeInfo for comparison with target shapeInfo
auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace());
if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) {
if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) {
delete []newShapeInfo;
throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !");
}
@ -319,14 +323,14 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here
//#pragma omp parallel for simd if(targetLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)
for(Nd4jLong i=0; i<targetLen; ++i) {
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
else if(target.ordering() == 'c' && ews > 1) {
for(Nd4jLong i=0; i<targetLen; ++i) {
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i*ews, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i*ews, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
else {
@ -334,8 +338,8 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
for(Nd4jLong i=0; i<targetLen; ++i) {
auto xOffset = target.getOffset(i);
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), xOffset, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), xOffset, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
}
@ -355,8 +359,8 @@ void NDArray::tile(NDArray& target) const {
if(target.ordering() == 'c' && ews >= 1) {
for(Nd4jLong i=0; i<targetLen; ++i) {
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i*ews, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i*ews, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
else {
@ -364,8 +368,8 @@ void NDArray::tile(NDArray& target) const {
for(Nd4jLong i=0; i<targetLen; ++i) {
auto xOffset = target.getOffset(i);
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), xOffset, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), xOffset, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
}
@ -388,8 +392,8 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, output.getShapeInfo(), coords);
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);
shape::index2coordsCPU(start, i, output.shapeInfo(), coords);
const auto zOffset = shape::getOffset(output.shapeInfo(), coords);
temp = coords[axis];
@ -404,7 +408,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
} else
coords[axis] /= repeats[0];
z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords)];
z[zOffset] = x[shape::getOffset(input.shapeInfo(), coords)];
coords[axis] = temp;
}

View File

@ -50,16 +50,16 @@
namespace sd {
void* NDArray::platformBuffer() { return specialBuffer(); }
void* NDArray::getPlatformBuffer() const { return getSpecialBuffer(); }
void const* NDArray::platformBuffer() const { return specialBuffer(); }
Nd4jLong* NDArray::getPlatformShapeInfo() const { return getSpecialShapeInfo(); }
Nd4jLong* NDArray::platformShapeInfo() { return specialShapeInfo(); }
Nd4jLong const* NDArray::platformShapeInfo() const { return specialShapeInfo(); }
//Nd4jLong const* NDArray::platformShapeInfo() { return specialShapeInfo(); }
void NDArray::syncToDevice() const {
auto currentDeviceId = AffinityManager::currentDeviceId();
if (currentDeviceId != _deviceId) {
// first of all we update shapeInfo
const_cast<NDArray*>(this)->setShapeInfo(this->getShapeInfo());
const_cast<NDArray*>(this)->setShapeInfo(this->shapeInfo());
// now we actually migrate data buffer
_buffer->migrate();
@ -142,7 +142,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
PointersManager manager(getContext(), "NDArray::fillAsTriangular");
NDArray::prepareSpecialUse({&target}, {this});
fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target.getPlatformBuffer(), target.getPlatformShapeInfo(), static_cast<T>(val), lower, upper);
fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(platformBuffer(), platformShapeInfo(), target.platformBuffer(), target.platformShapeInfo(), static_cast<T>(val), lower, upper);
NDArray::registerSpecialUse({&target}, {this});
manager.synchronize();
@ -206,7 +206,7 @@ void NDArray::setIdentity() {
PointersManager manager(getContext(), "NDArray::setIdentity");
syncToDevice();
BUILD_SINGLE_SELECTOR(dataType(), identityMatrixCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getPlatformBuffer(), getPlatformShapeInfo(), 1.f), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(dataType(), identityMatrixCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), platformBuffer(), platformShapeInfo(), 1.f), LIBND4J_TYPES);
tickWriteDevice();
manager.synchronize();
@ -293,12 +293,16 @@ void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, c
//////////////////////////////////////////////////////////////////////////
void NDArray::syncShape() const {
cudaMemcpy(getSpecialShapeInfo(), getShapeInfo(), shape::shapeInfoByteLength(getShapeInfo()), cudaMemcpyHostToDevice);
cudaMemcpy(const_cast<Nd4jLong*>(specialShapeInfo()), shapeInfo(), shape::shapeInfoByteLength(shapeInfo()), cudaMemcpyHostToDevice);
}
//////////////////////////////////////////////////////////////////////////
void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return getSpecialBuffer() != nullptr ? static_cast<int8_t*>(getSpecialBuffer()) + (offset * sizeOfT()) : nullptr;
void const* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return specialBuffer() != nullptr ? static_cast<int8_t const*>(specialBuffer()) + (offset * sizeOfT()) : nullptr;
}
void* NDArray::specialBufferWithOffset(Nd4jLong offset){
return specialBuffer() != nullptr ? static_cast<int8_t*>(specialBuffer()) + (offset * sizeOfT()) : nullptr;
}
//////////////////////////////////////////////////////////////////////////
@ -318,7 +322,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
NDArray result(*this);
if(diff < 0) { // reshape to higher dimension
std::vector<Nd4jLong> shapeNew = reps; // need to have unities at first "diff" positions of new shape
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
result.reshapei(ordering(), shapeNew);
}
return result; // nothing to do, if diff >= 0 -> identity tile
@ -332,13 +336,13 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext());
// fill newBuff, loop through all elements of newBuff
// looping through getBuffer() goes automatically by means of getSubArrayIndex applying
// looping through buffer() goes automatically by means of getSubArrayIndex applying
const auto resultLen = result.lengthOf();
auto xType = this->dataType();
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&result}, {this});
BUILD_SINGLE_SELECTOR(xType, tileKernelH, (this->getSpecialBuffer(), this->getSpecialShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), resultLen, stream), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(xType, tileKernelH, (this->specialBuffer(), this->specialShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), resultLen, stream), LIBND4J_TYPES);
registerSpecialUse({&result}, {this});
return result;
@ -354,18 +358,18 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
// evaluate true tile shapeInfo for comparison with target shapeInfo
auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace());
if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) {
if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) {
throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !");
}
// fill newBuff, loop through all elements of newBuff
// looping through getBuffer() goes automatically by means of getSubArrayIndex applying
// looping through buffer() goes automatically by means of getSubArrayIndex applying
const int ews = target.ews();
const int targetLen = target.lengthOf();
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this});
}
@ -384,7 +388,7 @@ void NDArray::tile(NDArray& target) const {
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this});
}
@ -467,7 +471,7 @@ NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
prepareSpecialUse({&output}, {this});
BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
prepareSpecialUse({&output}, {this});
manager.synchronize();
@ -491,7 +495,7 @@ void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& t
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
prepareSpecialUse({&target}, {this});
BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES);
prepareSpecialUse({&target}, {this});
manager.synchronize();
@ -501,16 +505,20 @@ void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& t
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
if (_buffer->special() == nullptr) {
syncToDevice();
tickReadHost();
}
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
void const* NDArray::specialBuffer() const {
if (_buffer->special() == nullptr) {
syncToDevice();
tickReadHost();
}
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
@ -526,7 +534,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
printf("%s", msg);
if(host) {
if(getBuffer() == nullptr || _length == 0)
if(buffer() == nullptr || _length == 0)
{ printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); return; }
const T* buff = bufferAsT<T>();
@ -535,7 +543,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
printf("\n");
}
else {
if(getSpecialBuffer() == nullptr || _length == 0)
if(specialBuffer() == nullptr || _length == 0)
{ printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); return; }
void* pHost = operator new(sizeof(T) * _length);
@ -545,7 +553,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
}
else
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
cudaMemcpyAsync(pHost, specialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
cudaError_t cudaResult = cudaStreamSynchronize(*getContext()->getCudaStream());
if(cudaResult != 0)

View File

@ -28,7 +28,7 @@ namespace sd {
_floatValues.emplace_back(values[e]);
}
ConstantDescriptor::ConstantDescriptor(Nd4jLong * values, int length) {
ConstantDescriptor::ConstantDescriptor(Nd4jLong const* values, int length) {
for (int e = 0; e < length; e++)
_integerValues.emplace_back(values[e]);
}

View File

@ -417,7 +417,7 @@ NDArray NDArrayFactory::create(const std::vector<T> &values, sd::LaunchContext *
NDArray res(buffer, ShapeDescriptor::vectorDescriptor(values.size(), DataTypeUtils::fromT<T>()), context);
memcpyFromVector<T>(res.getBuffer(), values);
memcpyFromVector<T>(res.buffer(), values);
res.tickWriteHost();
res.syncToDevice();

View File

@ -153,7 +153,7 @@ namespace sd {
inputs[e] = _chunks[e];
}
auto inShapeInfo = inputs[0]->getShapeInfo();
auto inShapeInfo = inputs[0]->shapeInfo();
int rank = shape::rank(inShapeInfo);
NDArray* array = nullptr;

View File

@ -26,7 +26,7 @@ namespace sd {
// _autoremovable = autoRemovable;
// }
ShapeList::ShapeList(Nd4jLong* shape) {
ShapeList::ShapeList(const Nd4jLong* shape) {
if (shape != nullptr)
_shapes.push_back(shape);
}
@ -36,21 +36,15 @@ namespace sd {
destroy();
}
ShapeList::ShapeList(std::initializer_list<Nd4jLong*> shapes) {
for (auto v:shapes)
_shapes.push_back(v);
}
ShapeList::ShapeList(std::initializer_list<Nd4jLong*> shapes, bool isWorkspace) : ShapeList(shapes){
ShapeList::ShapeList(const std::vector<const Nd4jLong*> &shapes, bool isWorkspace) : ShapeList(shapes){
_workspace = isWorkspace;
}
ShapeList::ShapeList(std::vector<Nd4jLong*>& shapes) {
for (auto v:shapes)
_shapes.push_back(v);
ShapeList::ShapeList(const std::vector<const Nd4jLong*>& shapes) {
_shapes = shapes;
}
std::vector<Nd4jLong*>* ShapeList::asVector() {
std::vector<const Nd4jLong*>* ShapeList::asVector() {
return &_shapes;
}
@ -66,33 +60,21 @@ namespace sd {
_destroyed = true;
}
int ShapeList::size() {
int ShapeList::size() const {
return (int) _shapes.size();
}
Nd4jLong* ShapeList::at(int idx) {
const Nd4jLong* ShapeList::at(int idx) {
if (_shapes.size() <= idx)
throw std::runtime_error("Can't find requested variable by index");
return _shapes.at(idx);
}
void ShapeList::push_back(Nd4jLong *shape) {
void ShapeList::push_back(const Nd4jLong *shape) {
_shapes.push_back(shape);
}
void ShapeList::push_back(std::vector<Nd4jLong>& shape) {
int dLen = shape::shapeInfoLength(shape.at(0));
if (shape.size() != dLen)
throw std::runtime_error("Bad shape was passed in");
auto nShape = new Nd4jLong[dLen];
std::memcpy(nShape, shape.data(), shape::shapeInfoByteLength(shape.at(0)));
_shapes.push_back(nShape);
}
void ShapeList::detach() {
for (int e = 0; e < _shapes.size(); e++) {
_shapes[e] = shape::detachShape(_shapes[e]);

View File

@ -29,18 +29,19 @@ namespace sd {
_numTads = numTads;
}
Nd4jLong* TadPack::primaryShapeInfo() const {
const Nd4jLong* TadPack::primaryShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
}
Nd4jLong* TadPack::primaryOffsets() const {
const Nd4jLong* TadPack::primaryOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
}
Nd4jLong* TadPack::specialShapeInfo() const {
const Nd4jLong* TadPack::specialShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.special());
}
Nd4jLong* TadPack::specialOffsets() const {
const Nd4jLong* TadPack::specialOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
}
@ -48,11 +49,11 @@ namespace sd {
return _numTads;
}
Nd4jLong* TadPack::platformShapeInfo() const {
const Nd4jLong* TadPack::platformShapeInfo() const {
return sd::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
}
Nd4jLong* TadPack::platformOffsets() const {
const Nd4jLong* TadPack::platformOffsets() const {
return sd::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
}

View File

@ -196,12 +196,14 @@ namespace sd {
#endif
void setInputArray(int index, NDArray *array, bool removable = false);
void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
void setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo);
void setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo);
void setInputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo);
void setOutputArray(int index, NDArray *array, bool removable = false);
void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
void setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo);
void setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo);
void setOutputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo);
void setTArguments(double *arguments, int numberOfArguments);
void setIArguments(Nd4jLong *arguments, int numberOfArguments);

View File

@ -407,8 +407,12 @@ namespace sd {
_handles.emplace_back(array);
}
void Context::setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong *>(shapeInfo));
void Context::setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) {
this->setInputArray(index, buffer, const_cast<const void*>(shapeInfo), specialBuffer, const_cast<const void *>(specialShapeInfo));
}
void Context::setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo) {
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong const*>(shapeInfo));
if (_fastpath_in.size() < index + 1)
_fastpath_in.resize(index+1);
@ -430,11 +434,15 @@ namespace sd {
_handles.emplace_back(array);
}
void Context::setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
void Context::setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) {
this->setOutputArray(index, buffer, const_cast<const void *>(shapeInfo), specialBuffer, const_cast<const void *>(specialShapeInfo));
}
void Context::setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo) {
if (_fastpath_out.size() < index + 1)
_fastpath_out.resize(index+1);
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong *>(shapeInfo));
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong const*>(shapeInfo));
_fastpath_out[index] = array;
_handles.emplace_back(array);
@ -443,7 +451,7 @@ namespace sd {
array->setContext(_context);
}
void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
void Context::setInputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_in.size() < index + 1)
@ -451,9 +459,9 @@ namespace sd {
NDArray *array;
if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong const*>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong const*>(shapeInfo))));
else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong const*>(shapeInfo));
_fastpath_in[index] = array;
_handles.emplace_back(array);
@ -462,7 +470,7 @@ namespace sd {
array->setContext(_context);
}
void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
void Context::setOutputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_out.size() < index + 1)
@ -470,9 +478,9 @@ namespace sd {
NDArray *array;
if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong const*>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong const*>(shapeInfo))));
else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong const*>(shapeInfo));
_fastpath_out[index] = array;
_handles.emplace_back(array);

View File

@ -50,8 +50,8 @@ namespace sd {
Nd4jLong result = 0L;
Nd4jLong lastStep = 0L;
std::vector<Nd4jLong *> shapes;
MAP_IMPL<std::pair<int,int>, Nd4jLong*> shapesMap;
std::vector<Nd4jLong const*> shapes;
MAP_IMPL<std::pair<int,int>, Nd4jLong const*> shapesMap;
int cntFD = 0;
@ -83,12 +83,12 @@ namespace sd {
auto in = node->input()->at(0);
auto block = node->getContextPrototype();
std::vector<Nd4jLong*> inputShapes;
std::vector<Nd4jLong const*> inputShapes;
int *oldShape;
for (auto v: *node->input()) {
nd4j_debug(" inputs for estimation are: %i:%i\n", v.first, v.second);
if (v.first < 0) {
inputShapes.push_back(_variableSpace->getVariable(v.first)->getNDArray()->getShapeInfo());
inputShapes.push_back(_variableSpace->getVariable(v.first)->getNDArray()->shapeInfo());
} else {
inputShapes.push_back(shapesMap.at(v));
}
@ -102,7 +102,7 @@ namespace sd {
int cnt = 0;
for (auto newShape: *outSha->asVector()) {
std::pair<int, int> pairAddr(node->id(), cnt++);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape);
std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape);
@ -122,11 +122,11 @@ namespace sd {
auto x = _variableSpace->getVariable(in);
auto z = _variableSpace->getVariable(node->id());
auto newShape = new Nd4jLong[shape::shapeInfoLength(x->getNDArray()->getShapeInfo())];
memcpy(newShape, x->getNDArray()->getShapeInfo(), shape::shapeInfoByteLength(x->getNDArray()->getShapeInfo()));
auto newShape = new Nd4jLong[shape::shapeInfoLength(x->getNDArray()->shapeInfo())];
memcpy(newShape, x->getNDArray()->shapeInfo(), shape::shapeInfoByteLength(x->getNDArray()->shapeInfo()));
std::pair<int, int> pairAddr(node->id(), 0);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape);
std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape);
@ -141,7 +141,7 @@ namespace sd {
memcpy(newShape, prevShape, shape::shapeInfoByteLength(prevShape));
std::pair<int, int> pairAddr(node->id(), 0);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape);
std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape);
@ -152,30 +152,30 @@ namespace sd {
}
} else if (node->getOpClass() == OpClass_REDUCTION) {
Nd4jLong *newShape = nullptr;
Nd4jLong const* newShape = nullptr;
// if that's scalar output - we don't care about previous node
if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == sd::DataTypeUtils::max<int>())) {
newShape = new Nd4jLong[8];
newShape[0] = 2;
newShape[1] = 1;
newShape[2] = 1;
newShape[3] = 1;
newShape[4] = 1;
newShape[5] = 8192; // set type as FLOAT32 by default
newShape[6] = 1;
newShape[7] = 99;
// auto aNewShape = new Nd4jLong[8];
//
// aNewShape[0] = 2;
// aNewShape[1] = 1;
// aNewShape[2] = 1;
// aNewShape[3] = 1;
// aNewShape[4] = 1;
// aNewShape[5] = 8192; // set type as FLOAT32 by default
// aNewShape[6] = 1;
// aNewShape[7] = 99;
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', {1,1});
} else {
auto in = node->input()->at(0);
Nd4jLong *oldShape = nullptr;
Nd4jLong const* oldShape = nullptr;
// calculate tads here
if (in.first < 0) {
auto x = _variableSpace->getVariable(in)->getNDArray();
oldShape = x->getShapeInfo();
oldShape = x->shapeInfo();
} else {
oldShape = shapesMap.at(in);
@ -188,7 +188,7 @@ namespace sd {
}
std::pair<int, int> pairAddr(node->id(), 0);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape);
std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape);

View File

@ -88,8 +88,8 @@ namespace sd {
void setObjectsSize(Nd4jLong bytes);
void setTotalSize(Nd4jLong bytes);
void addInputShape(Nd4jLong *shapeInfo);
void addOutputShape(Nd4jLong *shapeInfo);
void addInputShape(Nd4jLong const* shapeInfo);
void addOutputShape(Nd4jLong const* shapeInfo);
Nd4jLong getActivationsSize() const;
Nd4jLong getTemporarySize() const;

View File

@ -116,11 +116,11 @@ namespace sd {
return _executionTime;
}
void NodeProfile::addInputShape(Nd4jLong *shapeInfo) {
void NodeProfile::addInputShape(Nd4jLong const* shapeInfo) {
_inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
}
void NodeProfile::addOutputShape(Nd4jLong *shapeInfo) {
void NodeProfile::addOutputShape(Nd4jLong const*shapeInfo) {
_outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
}

View File

@ -51,20 +51,20 @@ namespace sd {
ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> dimensions = {});
ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
Nd4jLong* emptyShapeInfo(const sd::DataType dataType);
Nd4jLong* scalarShapeInfo(const sd::DataType dataType);
Nd4jLong* vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType);
Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor);
Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape);
Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
Nd4jLong* createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo);
const Nd4jLong* emptyShapeInfo(sd::DataType dataType);
const Nd4jLong* scalarShapeInfo(sd::DataType dataType);
const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType);
const Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor);
const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
const Nd4jLong* createShapeInfo(sd::DataType dataType, const Nd4jLong* shapeInfo);
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace);
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace);
const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor);

View File

@ -41,43 +41,43 @@ namespace sd {
public:
template <typename OpType>
static FORCEINLINE void loopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop);
static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop);
};
template <typename X, typename Z>
class ReductionFloatLoops : public ReductionLoops<X, Z, Z> {
public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
};
template <typename X, typename Z>
class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops<X, Z, X> {
public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
};
template <typename X, typename Z>
class ND4J_EXPORT ReductionLongLoops : public ReductionLoops<X, Z, X> {
public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
};
template <typename X>
class ND4J_EXPORT ReductionSameLoops : public ReductionLoops<X, X, X> {
public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
};
@ -85,10 +85,10 @@ namespace sd {
class ND4J_EXPORT IndexReductionLoops {
private:
public:
static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams);
static void wrapIndexReduce(int opNum, const void* x, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* extraParams);
template <typename OpType>
static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams);
static void loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams);
};
@ -98,7 +98,7 @@ namespace sd {
public:
template<typename OpType>
static FORCEINLINE void loopTransform(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, E* extraParams, uint64_t threadId, uint64_t numThreads);
static FORCEINLINE void loopTransform(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, E* extraParams, uint64_t threadId, uint64_t numThreads);
};
template <typename X, typename Z>
@ -106,20 +106,20 @@ namespace sd {
public:
template <typename OpType>
static FORCEINLINE void loopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
static FORCEINLINE void loopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static FORCEINLINE void loopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
static FORCEINLINE void loopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
static void wrapperAll(const int opNum, X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
static void wrapperAll(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static void innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
static void innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType>
static void innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
};
@ -263,10 +263,11 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename E>
template <typename OpType>
void sd::ReductionLoops<X, Z, E>::loopReduce(X* x, Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo,
Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets,
E* extraParams, int64_t start, int64_t stop) {
void sd::ReductionLoops<X, Z, E>::loopReduce(const X* x, const Nd4jLong* xShapeInfo,
Z* z, const Nd4jLong* zShapeInfo,
const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
E* extraParams,
int64_t start, int64_t stop) {
const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo);
@ -492,9 +493,10 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename E>
template <typename OpType>
void sd::TransformLoops<X, Z, E>::loopTransform(X* x, Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo,
E* extraParams, uint64_t threadId, uint64_t numThreads) {
void sd::TransformLoops<X, Z, E>::loopTransform(const X* x, const Nd4jLong* xShapeInfo,
Z* z, const Nd4jLong* zShapeInfo,
E* extraParams,
uint64_t threadId, uint64_t numThreads) {
const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
@ -682,11 +684,11 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpType>
void sd::Reduction3Loops<X, Z>::loopReduce3(X* x, Nd4jLong* xShapeInfo,
X* y, Nd4jLong* yShapeInfo,
Z* z, Nd4jLong* zShapeInfo,
int* dims, int dimsLen,
Z* extraParameters, int64_t start, int64_t stop) {
void sd::Reduction3Loops<X, Z>::loopReduce3(const X* x, const Nd4jLong* xShapeInfo,
const X* y, const Nd4jLong* yShapeInfo,
Z* z, const Nd4jLong* zShapeInfo,
int* dims, int dimsLen,
Z* extraParameters, int64_t start, int64_t stop) {
// both tads have same shape, however strides and ews may differ
@ -695,7 +697,7 @@ namespace sd {
const Nd4jLong xLen = shape::length(xShapeInfo);
const Nd4jLong yLen = shape::length(yShapeInfo);
Nd4jLong* xTadShapeInfo = nullptr, * yTadShapeInfo = nullptr, * xTadOffsets = nullptr, * yTadOffsets = nullptr;
const Nd4jLong* xTadShapeInfo = nullptr, * yTadShapeInfo = nullptr, * xTadOffsets = nullptr, * yTadOffsets = nullptr;
TadPack tadPackX, tadPackY;
std::vector<Nd4jLong> zeroOffsets;
@ -962,12 +964,13 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpType>
void sd::Reduction3Loops<X, Z>::loopReduce3All(X* x, Nd4jLong* xShapeInfo,
X* y, Nd4jLong* yShapeInfo,
Z* z, Nd4jLong* zShapeInfo,
Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets,
Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets,
Z* extraParameters, int64_t start, int64_t stop) {
void sd::Reduction3Loops<X, Z>::loopReduce3All(const X* x, const Nd4jLong* xShapeInfo,
const X* y, const Nd4jLong* yShapeInfo,
Z* z, const Nd4jLong* zShapeInfo,
const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets,
const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets,
Z* extraParameters,
int64_t start, int64_t stop) {
// both tads have same shape, however strides and ews may differ

View File

@ -35,28 +35,28 @@ namespace sd {
static std::vector<Nd4jLong> evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt);
// evaluate resulting shape after reduce operation
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
/**
* evaluate output shape for reduce operation when input shape is empty
* behavior is analogous to tf
*/
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace);
static const Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace);
// evaluate shape for array which is result of repeat operation applied to arr
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
// evaluate shapeInfo of permuted array
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false);
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace);
static const Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false);
static const Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace);
// evaluate shapeInfo of transposed array
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false);
static const Nd4jLong* evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false);
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
@ -67,13 +67,13 @@ namespace sd {
// check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end
static bool areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2);
static bool areShapesBroadcastable(Nd4jLong* shapeX, Nd4jLong* shapeY);
static bool areShapesBroadcastable(const Nd4jLong* shapeX, const Nd4jLong* shapeY);
static bool areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2);
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false then array with larger rank has to be passed as first argument
static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace);
static bool evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace);
static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace);
static bool evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace);
// evaluate sorted vector of max axes to create tads along in case of simple broadcast operation
// if simple broadcast is not possible then empty vector is returned
@ -88,10 +88,10 @@ namespace sd {
static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min);
// evaluate shapeInfo for resulting array of tile operation
static Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace);
static const Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace);
// returns shape part of shapeInfo as std::vector
static std::vector<Nd4jLong> pullShapeFromShapeInfo(Nd4jLong *shapeInfo);
static std::vector<Nd4jLong> pullShapeFromShapeInfo(const Nd4jLong *shapeInfo);
static std::string shapeAsString(const NDArray* array);
static std::string shapeAsString(const std::vector<Nd4jLong>& shape);
@ -104,13 +104,13 @@ namespace sd {
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
static Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, sd::memory::Workspace* workspace);
static const Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, sd::memory::Workspace* workspace);
static std::vector<int> evalBroadcastBackwardAxis(const Nd4jLong *operand, const Nd4jLong *result);
// utility to calculate matrix product shape with give source shapes and additional params
// returns ShapeList pointer with result shape
static Nd4jLong* matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace);
static const Nd4jLong* matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace);
/**
* This method evaluates permutation vector necessary for reducing of shapeFrom to shapeTo

View File

@ -55,20 +55,20 @@ namespace shape {
Nd4jLong tadIndex = 0;
int dimensionLength;
int* dimension = nullptr;
Nd4jLong *shapeInfo = nullptr;
Nd4jLong *tadOnlyShapeInfo = nullptr;
Nd4jLong const* shapeInfo = nullptr;
Nd4jLong* tadOnlyShapeInfo = nullptr;
Nd4jLong numTads = 0;
int tadRank = 0;
Nd4jLong *tadShape = nullptr;
Nd4jLong *tadStride = nullptr;
Nd4jLong *tadOffsets = nullptr;
Nd4jLong* tadShape = nullptr;
Nd4jLong* tadStride = nullptr;
Nd4jLong* tadOffsets = nullptr;
Nd4jLong tadOffsetForBlock = 0;
int rank = 0;
int numOnes = 0;
//pointers to original
int originalDimensionLength;
int *originalDimension = nullptr;
Nd4jLong *originalShapeInfo = nullptr;
int const* originalDimension = nullptr;
Nd4jLong const* originalShapeInfo = nullptr;
bool squeezed = false;
bool newSqueezeDimensions = false;
int numOnesInMiddle = 0;
@ -81,7 +81,7 @@ namespace shape {
void *ptrManager = nullptr;
int *ptrOutput = nullptr;
INLINEDEF bool dimensionsDescending(int rank, int *dimensions, int length);
INLINEDEF bool dimensionsDescending(int rank, int const* dimensions, int length);
#ifdef __CUDACC__
__host__ __device__
@ -114,12 +114,12 @@ namespace shape {
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF void init(Nd4jLong *shapeInfo,int *dimension,int dimensionLength);
INLINEDEF void init(Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength);
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF void init(int index, Nd4jLong *shapeInfo,int *dimension,int dimensionLength);
INLINEDEF void init(int index, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength);
@ -134,12 +134,12 @@ namespace shape {
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out);
INLINEDEF void permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong *out);
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearrange);
INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange);
@ -153,7 +153,7 @@ namespace shape {
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong *shapeBuffer);
INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong const* shapeBuffer);
#ifdef __CUDACC__
@ -253,7 +253,7 @@ namespace shape {
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength);
INLINEDEF Nd4jLong tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength);
/**
* Computes the number
@ -263,7 +263,7 @@ namespace shape {
#ifdef __CUDACC__
__host__ __device__
#endif
INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength);
INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength);
#ifdef __CUDACC__
@ -337,19 +337,19 @@ namespace shape {
this->wholeThing = this->numTads == 1 || ((this->dimensionLength == this->rank || this->numTads == shape::length(this->shapeInfo)) && ews == 1);
}
INLINEDEF void TAD::init(int tadIndex, Nd4jLong *shapeInfo,int *dimension,int dimensionLength) {
INLINEDEF void TAD::init(int tadIndex, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength) {
this->tadIndex = tadIndex;
this->init(shapeInfo, dimension, dimensionLength);
}
INLINEDEF void TAD::init(Nd4jLong *shapeInfo, int *dimension,int dimensionLength) {
INLINEDEF void TAD::init(Nd4jLong const* shapeInfo, int const* dimension,int dimensionLength) {
this->originalShapeInfo = shapeInfo;
this->originalDimension = dimension;
this->originalDimensionLength = dimensionLength;
//start off as original references
this->shapeInfo = shapeInfo;
this->dimensionLength = dimensionLength;
this->dimension = dimension;
this->dimension = const_cast<int*>(dimension);
this->rank = shape::rank(shapeInfo);
this->numTads = dimensionLength == 0 ? 1 : this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength);
@ -420,19 +420,19 @@ namespace shape {
}
INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong* shapeBuffer, int* rearrange, Nd4jLong* out) {
INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong* out) {
memcpy(out, shapeBuffer, sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank));
doPermuteShapeInfo(out, rearrange);
}
INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong* shapeBuffer, int *rearrange) {
INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange) {
int len = shape::shapeInfoLength(this->rank);
Nd4jLong *copy = shape::copyOf(len,shapeBuffer);
doPermuteShapeInfo(copy,rearrange);
return copy;
}
INLINEDEF bool TAD::dimensionsDescending(int rank, int *dimensions, int length) {
INLINEDEF bool TAD::dimensionsDescending(int rank, int const* dimensions, int length) {
int desired = rank - 1;
for (int e = length - 1; e >= 0; e--) {
if (dimensions[e] != desired--)
@ -465,7 +465,7 @@ namespace shape {
this->tadStride = shape::stride(this->tadOnlyShapeInfo);
}
INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong* shapeBuffer) {
INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const* shapeBuffer) {
int dimension = 0;
Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer),&dimension,shape::rank(shapeBuffer),1);
Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1);
@ -635,7 +635,7 @@ namespace shape {
}
INLINEDEF Nd4jLong* TAD::tensorShape() {
INLINEDEF Nd4jLong* TAD::tensorShape(){
if(this->tadShape != nullptr)
return this->tadShape;
@ -902,7 +902,7 @@ namespace shape {
}
INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) {
INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) {
if(dimensionLength == 1) {
return shape::shapeOf(shapeInfo)[dimension[0]];
}
@ -919,7 +919,7 @@ namespace shape {
}
INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) {
INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) {
return shape::length(shapeInfo) / this->tadLength(shapeInfo,dimension,dimensionLength);
}

View File

@ -55,22 +55,16 @@ namespace sd {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = 0;
_mutex.lock();
std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) {
auto hPtr = descriptor.toShapeInfo();
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer;
auto r = _cache[deviceId][descriptor1];
_mutex.unlock();
return r;
return _cache[deviceId][descriptor1];
} else {
auto r = _cache[deviceId].at(descriptor);
_mutex.unlock();
return r;
return _cache[deviceId].at(descriptor);
}
}
@ -82,52 +76,45 @@ namespace sd {
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
bool result;
int deviceId = 0;
_mutex.lock();
std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0)
result = false;
else
result = true;
_mutex.unlock();
return result;
return _cache[deviceId].count(descriptor) != 0;
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
}
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
const Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
const Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::scalarDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
const Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor);
@ -137,7 +124,7 @@ namespace sd {
return result;
}
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) {
const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) {
ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor);
@ -148,7 +135,7 @@ namespace sd {
////////////////////////////////////////////////////////////////////////
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> dimensions) {
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> &dimensions) {
Nd4jLong* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);

View File

@ -44,9 +44,9 @@ static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
const bool betaPersent = beta;
const Nd4jLong* aShapeInfo = vA->getShapeInfo();
const Nd4jLong* bShapeInfo = vB->getShapeInfo();
const Nd4jLong* cShapeInfo = vC->getShapeInfo();
const Nd4jLong* aShapeInfo = vA->shapeInfo();
const Nd4jLong* bShapeInfo = vB->shapeInfo();
const Nd4jLong* cShapeInfo = vC->shapeInfo();
const int aRank = vA->rankOf();
const int bRank = vB->rankOf();
@ -111,9 +111,9 @@ static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const
const bool betaPersent = beta;
const Nd4jLong* aShapeInfo = vA->getShapeInfo();
const Nd4jLong* xShapeInfo = vX->getShapeInfo();
const Nd4jLong* yShapeInfo = vY->getShapeInfo();
const Nd4jLong* aShapeInfo = vA->shapeInfo();
const Nd4jLong* xShapeInfo = vX->shapeInfo();
const Nd4jLong* yShapeInfo = vY->shapeInfo();
const int N = vX->lengthOf();
const int M = vY->lengthOf();
@ -294,13 +294,13 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
if(A->rankOf() != 2)
throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !");
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !");
const auto M = A->sizeAt(0);
const auto N = A->sizeAt(1);
if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim))
if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !");
if(X->lengthOf() != N)
throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !");
@ -347,10 +347,10 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
// choose appropriate cuda gemm api depending on data types
if(typeDouble) {
BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->getBuffer(), lda, (double*)X->getBuffer(), incx, beta, (double*)Y->getBuffer(), incy);
BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy);
}
else if(typeFloat) {
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy);
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), incy);
}
if(pA != A)
@ -371,9 +371,9 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0);
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
if(!shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::dot: Y array must be vector !");
if(Z != nullptr && !Z->isScalar())
throw std::runtime_error("MmulHelper::dot: Z array must be scalar !");
@ -393,8 +393,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
const auto yType = Y->dataType();
const auto zType = Z->dataType();
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES);
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), NUMERIC_TYPES);
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
return Z;
}
@ -419,9 +419,9 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
const bool betaPersent = beta;
const Nd4jLong* aShapeInfo = vA->getShapeInfo();
const Nd4jLong* bShapeInfo = vB->getShapeInfo();
const Nd4jLong* cShapeInfo = vC->getShapeInfo();
const Nd4jLong* aShapeInfo = vA->shapeInfo();
const Nd4jLong* bShapeInfo = vB->shapeInfo();
const Nd4jLong* cShapeInfo = vC->shapeInfo();
const int aRank = vA->rankOf();
const int bRank = vB->rankOf();
@ -576,13 +576,13 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
// multiplication
const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1});
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude);
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude);
std::vector<Nd4jLong> idxRanges(2 * C->rankOf());
// #pragma omp parallel for schedule(guided) firstprivate(idxRanges)
for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data());
ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data());
NDArray cSubArr = (*C)(idxRanges);
if(aRank > bRank) {

View File

@ -26,10 +26,10 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template <typename OpType>
void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo,
Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets,
X* extraParams) {
void sd::IndexReductionLoops<X,Z>::loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo,
Z* z, const Nd4jLong* zShapeInfo,
const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
X* extraParams) {
sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo);
if(kindOfLoop == sd::LoopKind::SMALLARR2DX)
@ -305,8 +305,8 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
template <typename X, typename Y>
void sd::IndexReductionLoops<X, Y>::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* vz, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) {
auto x = reinterpret_cast<X *>(vx);
void sd::IndexReductionLoops<X, Y>::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams) {
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Y *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong));
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong));

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
}
template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif

View File

@ -26,17 +26,18 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void ReductionBoolLoops<X, Z>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
void ReductionBoolLoops<X, Z>::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void ReductionBoolLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
X *extraParams, int64_t start, int64_t stop) {
void ReductionBoolLoops<X, Y>::wrapper(const int opNum,
const X *x, const Nd4jLong *xShapeInfo,
Y *z, const Nd4jLong *zShapeInfo,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
X *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS);
#endif

View File

@ -28,16 +28,18 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo,
Y *z, const Nd4jLong *zShapeInfo,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Y *extraParams,
int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif

View File

@ -28,16 +28,19 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Y>::wrapper(const int opNum,
const X *x, const Nd4jLong *xShapeInfo,
Y *z, const Nd4jLong *zShapeInfo,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Y *extraParams,
int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif

View File

@ -28,16 +28,16 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif

View File

@ -28,16 +28,16 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif

View File

@ -33,16 +33,16 @@ namespace sd {
template<typename X, typename Z>
template <typename OpType>
void ReductionLongLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z *z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
void ReductionLongLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X, typename Y>
void ReductionLongLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
void ReductionLongLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS);
#endif

View File

@ -26,16 +26,16 @@ namespace sd {
template<typename X>
template <typename OpType>
void ReductionSameLoops<X>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
void ReductionSameLoops<X>::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
}
template<typename X>
void ReductionSameLoops<X>::wrapper(const int opNum, X *vx, Nd4jLong *xShapeInfo, X *vz,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
void ReductionSameLoops<X>::wrapper(const int opNum, const X *vx, const Nd4jLong *xShapeInfo, X *vz,
const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
const Nd4jLong *tadOffsets,
X *vextraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
auto x = reinterpret_cast<X *>(vx);

View File

@ -83,40 +83,40 @@ namespace sd {
return _cache[deviceId].count(descriptor) != 0;
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
}
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
Nd4jLong const* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
Nd4jLong const* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::scalarDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
Nd4jLong const* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor);
@ -126,7 +126,7 @@ namespace sd {
return result;
}
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) {
Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) {
ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor);
@ -136,7 +136,7 @@ namespace sd {
}
////////////////////////////////////////////////////////////////////////
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> dimensions) {
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int>& dimensions) {
Nd4jLong* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);

View File

@ -268,8 +268,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
const int sharedMem = threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank
NDArray::prepareSpecialUse({C}, {A, B});
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({C}, {A, B});
auto cudaResult = cudaStreamSynchronize(*stream);
@ -319,23 +319,23 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
// choose appropriate cuda gemm api depending on data types
if(typeDouble) {
status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)pB->getSpecialBuffer(), ldb, &beta, (double*)pC->getSpecialBuffer(), ldc);
status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->specialBuffer(), lda, (double*)pB->specialBuffer(), ldb, &beta, (double*)pC->specialBuffer(), ldc);
}
else if(typeFloat) {
float alphaF(alpha), betaF(beta);
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->specialBuffer(), lda, (float*)pB->specialBuffer(), ldb, &betaF, (float*)pC->specialBuffer(), ldc);
}
else if(typeHalf) {
float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->specialBuffer(), lda, (__half*)pB->specialBuffer(), ldb, &betaH.data, (__half*)pC->specialBuffer(), ldc);
}
else if(typeIntFloat) {
float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_8I, lda, pB->specialBuffer(), CUDA_R_8I, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc);
}
else if(typeHalfFloat) {
float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_16F, lda, pB->specialBuffer(), CUDA_R_16F, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc);
}
if (status != CUBLAS_STATUS_SUCCESS)
@ -365,13 +365,13 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
if(A->rankOf() != 2)
throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !");
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !");
const auto M = A->sizeAt(0);
const auto N = A->sizeAt(1);
if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim))
if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !");
if(X->lengthOf() != N)
throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !");
@ -411,8 +411,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({Y}, {A, X});
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({Y}, {A, X});
auto cudaResult = cudaStreamSynchronize(*stream);
@ -442,11 +442,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
// choose appropriate cuda gemm api depending on data types
if(typeDouble) {
status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)X->getSpecialBuffer(), incx, &beta, (double*)Y->getSpecialBuffer(), incy);
status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->specialBuffer(), lda, (double*)X->specialBuffer(), incx, &beta, (double*)Y->specialBuffer(), incy);
}
else if(typeFloat) {
float alphaF(alpha), betaF(beta);
status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)X->getSpecialBuffer(), incx, &betaF, (float*)Y->getSpecialBuffer(), incy);
status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alphaF, (float*)pA->specialBuffer(), lda, (float*)X->specialBuffer(), incx, &betaF, (float*)Y->specialBuffer(), incy);
}
if (status != CUBLAS_STATUS_SUCCESS)
@ -471,9 +471,9 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0);
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
if(!shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !");
if(Z != nullptr && !Z->isScalar())
throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !");
@ -506,8 +506,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
NDArray::prepareSpecialUse({Z}, {X, Y});
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES)
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@ -667,8 +667,8 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
cBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int)));
NDArray::prepareSpecialUse({C}, {A, B});
// BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES)
// BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({C}, {A, B});
manager.synchronize();
@ -797,13 +797,13 @@ NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C,
// multiplication
const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1});
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude);
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude);
std::vector<Nd4jLong> idxRanges(2 * C->rankOf());
// #pragma omp parallel for schedule(guided) firstprivate(idxRanges)
for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data());
ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data());
NDArray cSubArr = (*C)(idxRanges);
if(aRank > bRank) {
@ -944,18 +944,18 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
std::vector<void*> aSubArrs(bS), bSubArrs(bS), cSubArrs(bS);
if(aRank > 2)
shape::calcSubArrsShapeInfoAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pA->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
aSubArrs[i] = aRank == 2 ? pA->getSpecialBuffer() : pA->getSpecialBuffer() + subArrOffsets[i] * pA->sizeOfT();
aSubArrs[i] = aRank == 2 ? pA->specialBuffer() : pA->specialBuffer() + subArrOffsets[i] * pA->sizeOfT();
if(bRank > 2)
shape::calcSubArrsShapeInfoAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pB->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
bSubArrs[i] = bRank == 2 ? pB->getSpecialBuffer() : pB->getSpecialBuffer() + subArrOffsets[i] * pB->sizeOfT();
bSubArrs[i] = bRank == 2 ? pB->specialBuffer() : pB->specialBuffer() + subArrOffsets[i] * pB->sizeOfT();
shape::calcSubArrsShapeInfoAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pC->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
cSubArrs[i] = pC->getSpecialBuffer() + subArrOffsets[i] * pC->sizeOfT();
cSubArrs[i] = pC->specialBuffer() + subArrOffsets[i] * pC->sizeOfT();
PointersManager manager(A->getContext(), "mmulNxN");
@ -1011,7 +1011,7 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
for(Nd4jLong i = 0; i < bS; ++i) {
ShapeUtils::evalIdxRangesForSubArr(i, pC->getShapeInfo(), dimsToExclude, idxRanges.data());
ShapeUtils::evalIdxRangesForSubArr(i, pC->shapeInfo(), dimsToExclude, idxRanges.data());
NDArray cSubArr = (*pC)(idxRanges);
if(aRank > bRank) {

View File

@ -91,7 +91,7 @@ void sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::N
mmul(aPR, bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
if(cPR->buffer() != cP->buffer() || cPR->specialBuffer() != cP->specialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->buffer()
cP->assign(cPR);
if(aP != aPR)
@ -150,7 +150,7 @@ void sd::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, c
// check whether new buffer allocation was happened for c array
if(!whatToDoWithC.empty()) {
for(int i = cArrs.size()-1; i > 0; --i) {
if(cArrs[i]->getBuffer() != cArrs[i-1]->getBuffer() || cArrs[i]->getSpecialBuffer() != cArrs[i-1]->getSpecialBuffer())
if(cArrs[i]->buffer() != cArrs[i-1]->buffer() || cArrs[i]->specialBuffer() != cArrs[i-1]->specialBuffer())
cArrs[i-1]->assign(cArrs[i]);
delete cArrs[i];
}
@ -203,8 +203,8 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
int lenDim;
const int aRank = A->rankOf();
const int bRank = B->rankOf();
const bool isAVector = shape::isCommonVector(A->getShapeInfo(), lenDim);
const bool isBVector = shape::isCommonVector(B->getShapeInfo(), lenDim);
const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim);
const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim);
// dot product of 2 vectors
if(isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4)
@ -243,7 +243,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
int xRank = x->rankOf();
int yRank = y->rankOf();
auto outShape = ShapeUtils::evalShapeForMatmul(x->getShapeInfo(), y->getShapeInfo(), transX, transY);
auto outShape = ShapeUtils::evalShapeForMatmul(x->shapeInfo(), y->shapeInfo(), transX, transY);
if(!z->isSameShape(outShape)) {
nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
throw std::invalid_argument("");
@ -285,7 +285,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
for(int i = 0; i < batchRank; ++i)
dimsToExclude[i] = i;
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->getShapeInfo(), dimsToExclude);
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->shapeInfo(), dimsToExclude);
//PRAGMA_OMP_PARALLEL_FOR
for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {

View File

@ -118,13 +118,13 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
//////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt) {
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
}
//////////////////////////////////////////////////////////////////////////
// evaluate output shape for reduce operation when input shape is empty
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) {
const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) {
if (dimsToExclude.size() == 0) { // return copy of input shape
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
@ -171,22 +171,22 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
}
const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
}
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
}
const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
}
const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);
@ -314,39 +314,39 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {
const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {
if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
auto shapeInfoLength = shape::shapeInfoLength(rank);
auto shapeInfoLength = shape::shapeInfoLength(rank);
// allocate memory for new array - shapeInfo
Nd4jLong *shapeInfoNew = nullptr;
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
// allocate memory for new array - shapeInfo
Nd4jLong *shapeInfoNew = nullptr;
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
// copy arr _shapeInfo into new array
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
// copy arr _shapeInfo into new array
memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank));
// perform buffer permutation
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
// perform buffer permutation
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
if(setContigStrides)
shape::updateStrides(shapeInfoNew, arr.ordering());
if(setContigStrides)
shape::updateStrides(shapeInfoNew, arr.ordering());
ShapeDescriptor descriptor(shapeInfoNew);
ShapeDescriptor descriptor(shapeInfoNew);
RELEASE(shapeInfoNew, workspace);
RELEASE(shapeInfoNew, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) {
const Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) {
std::vector<int> dims(dimensions, dimensions + rank);
return evalPermShapeInfo(dims.data(), rank, arr, workspace);
@ -354,7 +354,7 @@ Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, c
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of transposed array
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {
const Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {
int rank = arr.rankOf();
std::vector<int> dimensions(rank);
@ -414,10 +414,10 @@ std::vector<int> ShapeUtils::evalDimsToExclude(const int rank, const std::vector
// check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end
bool ShapeUtils::areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2) {
return areShapesBroadcastable(arr1.getShapeInfo(), arr2.getShapeInfo());
return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo());
}
bool ShapeUtils::areShapesBroadcastable(Nd4jLong *shapeInfo1, Nd4jLong *shapeInfo2) {
bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) {
int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) ? shape::rank(shapeInfo1) : shape::rank(shapeInfo2);
for (int i = -1; i >= -minRank; --i)
@ -427,177 +427,177 @@ bool ShapeUtils::areShapesBroadcastable(Nd4jLong *shapeInfo1, Nd4jLong *shapeInf
return true;
}
bool ShapeUtils::areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2) {
bool ShapeUtils::areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2) {
const auto rank1 = shape1.size();
const auto rank2 = shape2.size();
const int minRank = rank1 < rank2 ? rank1 : rank2;
const auto rank1 = shape1.size();
const auto rank2 = shape2.size();
const int minRank = rank1 < rank2 ? rank1 : rank2;
for (int i = 1; i <= minRank; ++i)
if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1)
for (int i = 1; i <= minRank; ++i)
if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1)
return false;
return true;
}
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false the array with larger rank has to be passed as first argument
bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, resultShapeInfo, workspace);
}
bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
// check whether broadcast operation is possible for input arrays
if(!areShapesBroadcastable(max, min))
return false;
return true;
}
auto maxShapeInfo = max; //max.shapeInfo();
auto minShapeInfo = min; //min.shapeInfo();
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false the array with larger rank has to be passed as first argument
bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
return evalBroadcastShapeInfo(max.getShapeInfo(), min.getShapeInfo(), evalMinMax, resultShapeInfo, workspace);
}
if(evalMinMax && (shape::rank(max) < shape::rank(min))) {
maxShapeInfo = min;
minShapeInfo = max;
}
bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
const auto maxRank = shape::rank(maxShapeInfo);
const auto minRank = shape::rank(minShapeInfo);
// check whether broadcast operation is possible for input arrays
if(!areShapesBroadcastable(max, min))
return false;
// evaluate shapeInfo for resulting array
if(resultShapeInfo != nullptr)
throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
auto maxShapeInfo = max; //max.getShapeInfo();
auto minShapeInfo = min; //min.getShapeInfo();
Nd4jLong *tmpShapeInfo = nullptr;
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
if(evalMinMax && (shape::rank(max) < shape::rank(min))) {
maxShapeInfo = min;
minShapeInfo = max;
// FIXME: get rid of memcpy here
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
for (int i = 0; i < minRank; ++i)
if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
if (shape::isEmpty(max) || shape::isEmpty(min)) {
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
}
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
return true;
}
const auto maxRank = shape::rank(maxShapeInfo);
const auto minRank = shape::rank(minShapeInfo);
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) {
// evaluate shapeInfo for resulting array
if(resultShapeInfo != nullptr)
throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
if(resultShapeInfo != nullptr)
throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
Nd4jLong *tmpShapeInfo = nullptr;
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
int size = arrays.size();
int maxRank = arrays[size - 1]->rankOf();
// FIXME: get rid of memcpy here
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
for (int i = 0; i < minRank; ++i)
if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
for(int i = 0; i < size - 1; ++i) {
if(arrays[i]->rankOf() > maxRank)
maxRank = arrays[i]->rankOf();
for(int j = i + 1; j < size; ++j)
if(!areShapesBroadcastable(*arrays[i], *arrays[j]))
return false;
}
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
Nd4jLong *tmpShapeInfo = nullptr;
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank));
tmpShapeInfo[0] = maxRank;
if (shape::isEmpty(max) || shape::isEmpty(min)) {
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
for(const auto& item : arrays ) {
for(int i = -1; i >= -item->rankOf(); --i)
if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i))
tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i);
}
shape::updateStrides(tmpShapeInfo, arrays[0]->ordering());
ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType());
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = const_cast<Nd4jLong*>(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
return true;
}
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
return true;
}
//////////////////////////////////////////////////////////////////////////
// return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) {
const NDArray *min, *max;
if(resultShapeInfo != nullptr)
throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
if(arr1.rankOf() >= arr2.rankOf()) {
max = &arr1;
min = &arr2;
}
else {
max = &arr2;
min = &arr1;
}
int size = arrays.size();
int maxRank = arrays[size - 1]->rankOf();
const int rankDiff = max->rankOf() - min->rankOf();
for(int i = 0; i < size - 1; ++i) {
if(arrays[i]->rankOf() > maxRank)
maxRank = arrays[i]->rankOf();
for(int j = i + 1; j < size; ++j)
if(!areShapesBroadcastable(*arrays[i], *arrays[j]))
return false;
std::vector<int> dims;
for (int i = 0; i < min->rankOf(); ++i)
if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
dims.emplace_back(rankDiff + i);
return dims;
}
Nd4jLong *tmpShapeInfo = nullptr;
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank));
tmpShapeInfo[0] = maxRank;
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for resulting array from tile operation
const Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace) {
// check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
int repsSize = reps.size();
Nd4jLong product = 1;
for(const auto& item : reps)
product *= item;
if(product == 0)
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
for(const auto& item : arrays ) {
for(int i = -1; i >= -item->rankOf(); --i)
if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i))
tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i);
int rankOld = arr.rankOf();
int diff = rankOld - repsSize;
// evaluate new shapeInfo
Nd4jLong* newShapeInfo = nullptr;
if(diff < 0) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
newShapeInfo[0] = repsSize; // set new rank
for(int i=1; i <= -diff; ++i)
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
memcpy(newShapeInfo + 1 - diff, arr.shapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
for(int i=1; i <= repsSize; ++i)
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
else {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
memcpy(newShapeInfo, arr.shapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
for(int i=1; i <= repsSize; ++i)
newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
shape::updateStrides(newShapeInfo, arr.ordering());
ArrayOptions::setDataType(newShapeInfo, arr.dataType());
ShapeDescriptor descriptor(newShapeInfo);
RELEASE(newShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
shape::updateStrides(tmpShapeInfo, arrays[0]->ordering());
ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType());
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
return true;
}
//////////////////////////////////////////////////////////////////////////
// return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {
const NDArray *min, *max;
if(arr1.rankOf() >= arr2.rankOf()) {
max = &arr1;
min = &arr2;
}
else {
max = &arr2;
min = &arr1;
}
const int rankDiff = max->rankOf() - min->rankOf();
std::vector<int> dims;
for (int i = 0; i < min->rankOf(); ++i)
if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
dims.emplace_back(rankDiff + i);
return dims;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for resulting array from tile operation
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace) {
// check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
int repsSize = reps.size();
Nd4jLong product = 1;
for(const auto& item : reps)
product *= item;
if(product == 0)
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
int rankOld = arr.rankOf();
int diff = rankOld - repsSize;
// evaluate new shapeInfo
Nd4jLong* newShapeInfo = nullptr;
if(diff < 0) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
newShapeInfo[0] = repsSize; // set new rank
for(int i=1; i <= -diff; ++i)
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
for(int i=1; i <= repsSize; ++i)
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
else {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
for(int i=1; i <= repsSize; ++i)
newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
shape::updateStrides(newShapeInfo, arr.ordering());
ArrayOptions::setDataType(newShapeInfo, arr.dataType());
ShapeDescriptor descriptor(newShapeInfo);
RELEASE(newShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
std::vector<Nd4jLong> ShapeUtils::pullShapeFromShapeInfo(Nd4jLong *shapeInfo) {
std::vector<Nd4jLong> ShapeUtils::pullShapeFromShapeInfo(const Nd4jLong *shapeInfo) {
std::vector<Nd4jLong> shape(shape::rank(shapeInfo));
int shapeSize = shape.size();
@ -624,7 +624,7 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
std::string ShapeUtils::strideAsString(const NDArray* array) {
std::string result;
auto shapeBuffer = array->getShapeInfo(); //Nd4jLong*
auto shapeBuffer = array->shapeInfo(); //Nd4jLong*
int rank = (int)*shapeBuffer;
result.append("[");
for (int e = 0; e < rank; e++) {
@ -724,31 +724,31 @@ std::vector<Nd4jLong> ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) {
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){
auto shapeInfo = const_cast<Nd4jLong*>(shapeInfoConst);
const Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){
auto shapeInfo = const_cast<Nd4jLong*>(shapeInfoConst);
const auto rank = shape::rank(shapeInfo);
const auto rank = shape::rank(shapeInfo);
Nd4jLong* outputShapeInfo = nullptr;
Nd4jLong* outputShapeInfo = nullptr;
if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
outputShapeInfo[0] = 2;
outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo);
if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
outputShapeInfo[0] = 2;
outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo);
}
else {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong);
outputShapeInfo[0] = 2*rank;
for(int i = 1; i <= rank; ++i)
outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i];
}
ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo);
RELEASE(outputShapeInfo, workspace);
return result;
}
else {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong);
outputShapeInfo[0] = 2*rank;
for(int i = 1; i <= rank; ++i)
outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i];
}
ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo);
RELEASE(outputShapeInfo, workspace);
return result;
}
std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandShapeInfo, const Nd4jLong *resultShapeInfo) {
// rRank >= oRank always !!
@ -765,83 +765,82 @@ std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandSh
}
////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeUtils::matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace) {
const Nd4jLong* ShapeUtils::matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace) {
auto inA = theFirstShape;
auto inB = theSecondShape;
Nd4jLong *shape;
ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong);
auto inA = theFirstShape;
auto inB = theSecondShape;
Nd4jLong *shape;
ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong);
Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace);
Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace);
Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace);
Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace);
if (shouldTranspondFirst)
shape::transposeInplace(tmpA);
if (shouldTranspondFirst)
shape::transposeInplace(tmpA);
if (shouldTranspondSecond)
shape::transposeInplace(tmpB);
if (shouldTranspondSecond)
shape::transposeInplace(tmpB);
if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) {
// special case here
shape[0] = 1;
shape[1] = tmpB[2];
Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
} else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
// just scalar vs scalar
shape[0] = 1;
shape[1] = 1;
} else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
// gemv case
if (shape::rank(tmpB) == 2) {
shape[0] = tmpA[1];
if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) {
// special case here
shape[0] = 1;
shape[1] = tmpB[2];
} else {
// we have new 1D shape here
auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace);
Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
} else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
// just scalar vs scalar
shape[0] = 1;
shape[1] = 1;
} else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
// gemv case
if (shape::rank(tmpB) == 2) {
shape[0] = tmpA[1];
shape[1] = tmpB[2];
} else {
// we have new 1D shape here
auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
}
} else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isVector(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
// gemm case
shape[0] = tmpA[1];
shape[1] = tmpB[2];
} else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) ||
(shape::isScalar(tmpA) && shape::isVector(tmpB))) {
// element-wise
shape[0] = 1;
shape[1] = (int) sd::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
} else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
} else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
}
} else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isVector(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
// gemm case
shape[0] = tmpA[1];
shape[1] = tmpB[2];
} else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) ||
(shape::isScalar(tmpA) && shape::isVector(tmpB))) {
// element-wise
shape[0] = 1;
shape[1] = (int) sd::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
} else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
} else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'f', 2, shape);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
}
Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
}
////////////////////////////////////////////////////////////////////////////////
std::vector<int> ShapeUtils::evalPermutFromTo(const std::vector<Nd4jLong>& shapeFrom, const std::vector<Nd4jLong>& shapeTo) {
auto rank = shapeFrom.size();

View File

@ -65,7 +65,7 @@ namespace shape {
* the information on an ndarray
*/
struct ND4J_EXPORT ShapeInformation {
_CUDA_HD ShapeInformation(Nd4jLong *shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0)
_CUDA_HD ShapeInformation(Nd4jLong* shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0)
: shape(shape_), stride(stride_), order(order_), rank(rank_), offset(offset_), elementWiseStride(elementWiseStride_)
{}
@ -93,19 +93,19 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2);
ND4J_EXPORT _CUDA_HD Nd4jLong* detachShape(Nd4jLong *originalShape);
ND4J_EXPORT _CUDA_HD const Nd4jLong* detachShape(const Nd4jLong *originalShape);
ND4J_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong *originalShape);
ND4J_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong const* originalShape);
ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2);
ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3);
ND4J_EXPORT _CUDA_HD bool strideEquals(int shape1Rank,Nd4jLong *shape1,int shape2Rank,Nd4jLong *shape2);
ND4J_EXPORT _CUDA_HD bool strideEquals(int const shape1Rank,Nd4jLong const* shape1,int const shape2Rank, Nd4jLong const* shape2);
ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *shapeInfo1,Nd4jLong *shapeInfo2);
ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1, Nd4jLong const* shapeInfo2);
ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *stride1,int rank1,Nd4jLong *stride2,int rank2);
ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1, Nd4jLong const* stride2, int const rank2);
ND4J_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB);
@ -128,7 +128,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength);
ND4J_EXPORT _CUDA_HD Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder);
@ -142,17 +142,17 @@ namespace shape {
* Get the shape info buffer
* for the given rank and shape.
*/
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *buffer);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer);
/**
* Get the shape info buffer
* for the given rank and shape.
*/
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *output);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *output);
#ifdef __CUDACC__
@ -168,9 +168,9 @@ namespace shape {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, Nd4jLong* ret);
ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret);
/**
* Computes the standard packed array strides for a given shape.
@ -180,9 +180,9 @@ namespace shape {
* @return the strides for a matrix of n dimensions
*/
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, Nd4jLong* ret);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret);
ND4J_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order);
ND4J_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order);
@ -199,9 +199,9 @@ namespace shape {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret);
/**
* Computes the standard packed array strides for a given shape.
@ -210,9 +210,9 @@ namespace shape {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const* shape, int rank, int startNum);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret);
/**
* @param toCopy the shape to copy
@ -244,7 +244,7 @@ namespace shape {
* @return 0 if there is no element wise stride the
* element wise stride of reshape(1,length) otherwise
*/
ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder);
ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder);
/**
* Compute the element wise stride
@ -257,11 +257,11 @@ namespace shape {
* @return 0 if there is no element wise stride the
* element wise stride of reshape(1,length) otherwise
*/
ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder, Nd4jLong *dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder, Nd4jLong const* dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong const* shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer);
/**
*
* @param length
@ -281,7 +281,7 @@ namespace shape {
*/
ND4J_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange);
ND4J_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong *shapeBuffer, int* rearrange);
ND4J_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange);
ND4J_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out);
@ -304,7 +304,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong* computeResultShape(Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong* computeResultShape(const Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength);
/**
* This method does inplace transpose of given shapeBuffer
@ -350,7 +350,7 @@ namespace shape {
* @param shape the shape of the array
* @param rank the rank of cthe shape
*/
ND4J_EXPORT _CUDA_HD int isVector(Nd4jLong *shape, int rank);
ND4J_EXPORT _CUDA_HD int isVector(Nd4jLong const* shape, int rank);
/**
@ -363,13 +363,13 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong *shapeInfo, int& posOfNonUnityDim);
ND4J_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim);
ND4J_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim);
ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong const* shapeInfo);
/**
* shape - input inShape is shape only, not shapeInfo
@ -401,10 +401,10 @@ namespace shape {
*/
template <typename T>
ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy);
ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy);
template <typename T>
ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy, T *ret);
ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret);
/**
* Return a copy of a buffer.
@ -413,13 +413,13 @@ namespace shape {
*/
template <typename T>
ND4J_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T *from, T *to);
ND4J_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to);
/**
* Return a copy of a buffer.
* This buffer allocates memory
* that must be freed elsewhere.
*/
ND4J_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong *from, Nd4jLong *to, Nd4jLong *indexes);
ND4J_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes);
/**
* Permute the given strides
@ -566,7 +566,7 @@ namespace shape {
* item
*/
template <typename T1, typename T2>
ND4J_EXPORT _CUDA_HD void removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out);
ND4J_EXPORT _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out);
/**
* Return a copy of this array with the
@ -582,7 +582,7 @@ namespace shape {
*/
template <typename T1, typename T2>
ND4J_EXPORT _CUDA_HD T1* removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength);
ND4J_EXPORT _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength);
/**
* Iterate over a given set of indexes
@ -595,7 +595,7 @@ namespace shape {
* indexes should be the indexes to exclude
* indexes length should be the length of indexes
*/
ND4J_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong *indexes,int indexesLength,int begin,int end);
ND4J_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong const* indexes,int indexesLength,int begin,int end);
/**
* Computes the offset for accessing
@ -641,7 +641,7 @@ namespace shape {
* Keep the given indexes
* in the data
*/
ND4J_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int* index, int indexLength, int dataLength);
ND4J_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength);
/**
* Generate reverse copy of the data
@ -651,13 +651,13 @@ namespace shape {
*/
template <typename T>
ND4J_EXPORT _CUDA_HD T* reverseCopy(T *data, Nd4jLong length);
ND4J_EXPORT _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length);
template <typename T>
ND4J_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong length);
ND4J_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length);
template <typename T>
ND4J_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong *indexes, Nd4jLong length);
ND4J_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length);
template <typename T1, typename T2>
ND4J_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length);
@ -670,7 +670,7 @@ namespace shape {
* @return
*/
template <typename T>
ND4J_EXPORT _CUDA_HD T* concat(T* arr1, Nd4jLong arr1Length, T* arr2, Nd4jLong arr2Length);
ND4J_EXPORT _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length);
/**
*
@ -681,7 +681,7 @@ namespace shape {
* @return
*/
template <typename T>
ND4J_EXPORT _CUDA_HD T* concat(int numArrays, int numTotalElements, Nd4jLong **arr, Nd4jLong *lengths);
ND4J_EXPORT _CUDA_HD T* concat(int const numArrays, int const numTotalElements, Nd4jLong const**arr, Nd4jLong const* lengths);
/**
* Get the length per slice of the
@ -695,7 +695,7 @@ namespace shape {
* @return the length per slice of the given shape
* along the given dimension
*/
ND4J_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong *shape, int *dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength);
/**
* calculates the offset for a tensor
@ -706,10 +706,10 @@ namespace shape {
*/
ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank,
int index,
Nd4jLong *shape,
Nd4jLong *tensorShape,
Nd4jLong const* shape,
Nd4jLong const* tensorShape,
int tensorShapeLength,
int *dimension,
int const *dimension,
int dimensionLength);
/**
@ -1095,7 +1095,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* Length of a tad given
* the shape information
*/
INLINEDEF _CUDA_HD Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) {
INLINEDEF _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength) {
if(dimensionLength == 1) {
return shape::shapeOf(shapeInfo)[dimension[0]];
}
@ -1166,7 +1166,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
}
INLINEDEF _CUDA_HD bool strideEquals(int shape1Rank,Nd4jLong *shape1,int shape2Rank,Nd4jLong *shape2) {
INLINEDEF _CUDA_HD bool strideEquals(int const shape1Rank, Nd4jLong const* shape1,int const shape2Rank,Nd4jLong const* shape2) {
if(shape1Rank != shape2Rank)
return false;
//rank not equals
@ -1178,12 +1178,12 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return true;
}
INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong *shapeInfo1,Nd4jLong *shapeInfo2) {
INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1,Nd4jLong const* shapeInfo2) {
return shape::strideEquals(shape::rank(shapeInfo1),shape::stride(shapeInfo1),shape::rank(shapeInfo2),shape::stride(shapeInfo2));
}
INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong *stride1,int rank1 , Nd4jLong *stride2, int rank2) {
INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1 , Nd4jLong const* stride2, int const rank2) {
if(rank1 != rank2)
return false;
@ -1195,7 +1195,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return true;
}
INLINEDEF _CUDA_HD Nd4jLong *computeResultShape(Nd4jLong *originalShapeBuffer, int* dimension,int dimensionLength) {
INLINEDEF _CUDA_HD Nd4jLong *computeResultShape(Nd4jLong const* originalShapeBuffer, int * dimension,int dimensionLength) {
Nd4jLong *retShape;
int retShapeLength;
if(dimensionLength == 1 && dimension[0] == 2147483647) {
@ -1236,7 +1236,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
}
INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer) {
INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer) {
Nd4jLong *theShape = shape::shapeOf(shapeInfo);
Nd4jLong *theStride = shape::stride(shapeInfo);
int rank = dimensionLength == 1 ? 2 : dimensionLength;
@ -1279,7 +1279,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
}
else {
Nd4jLong *newIndexes = dimension;
Nd4jLong *newIndexes = dimension;
if(reverseCopyStride)
shape::reverseCopyTo(theStride, retStride, newIndexes, len);
else
@ -1293,7 +1293,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return ret;
}
INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride) {
INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride) {
int rank = dimensionLength == 1 ? 2 : dimensionLength;
traceNew(4);
@ -1330,7 +1330,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, int startNum) {
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum) {
if (isVector(shape, rank)) {
traceNew(5);
@ -1356,7 +1356,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return stride;
}
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, int startNum, Nd4jLong *ret) {
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum, Nd4jLong *ret) {
if (isVector(shape, rank)) {
for (int i = 0; i < rank; i++)
ret[i] = 1;
@ -1382,7 +1382,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong *shape, int rank, int startNum) {
INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const *shape, int rank, int startNum) {
traceNew(7);
@ -1410,7 +1410,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return stride;
}
INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret) {
INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const* shape, int rank, int startNum, Nd4jLong* ret) {
if (rank == 1) {
ret[0] = 1;
return ret;
@ -1439,11 +1439,11 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank) {
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank) {
return calcStridesFortran(shape, rank, 1);
}
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, Nd4jLong* ret) {
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret) {
return calcStridesFortran(shape, rank, 1, ret);
}
@ -1454,11 +1454,11 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions
*/
INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank) {
INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank) {
return calcStrides(shape, rank, 1);
}
INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, Nd4jLong* ret) {
INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret) {
return calcStrides(shape, rank, 1, ret);
}
@ -1541,7 +1541,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return copy;
}
INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder) {
INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder) {
if (rank == 0)
return 1;
@ -1690,8 +1690,8 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
}
INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder,
Nd4jLong *dimension, int dimensionLength) {
INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder,
Nd4jLong const* dimension, int dimensionLength) {
if(dimensionLength == 1) {
return stride[dimension[0]];
}
@ -1703,13 +1703,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* Get the shape info buffer
* for the given rank and shape.
*/
INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape) {
INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape) {
Nd4jLong *stride = shape::calcStrides(shape, rank);
traceNew(11);
auto shapeInfo = new shape::ShapeInformation();
shapeInfo->shape = shape;
shapeInfo->shape = const_cast<Nd4jLong*>(shape);
shapeInfo->stride = stride;
shapeInfo->offset = 0;
shapeInfo->rank = rank;
@ -1728,13 +1728,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
*
* This method is used only for SoftMax
*/
INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *buffer) {
INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer) {
Nd4jLong stride[MAX_RANK];
shape::calcStrides(shape,rank, stride);
shape::ShapeInformation shapeInfo;
shapeInfo.shape = shape;
shapeInfo.shape = const_cast<Nd4jLong*>(shape);
shapeInfo.stride = stride;
shapeInfo.offset = 0;
shapeInfo.rank = rank;
@ -1751,13 +1751,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* Get the shape info buffer
* for the given rank and shape.
*/
INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape) {
INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape) {
auto stride = shape::calcStridesFortran(shape,rank);
traceNew(12);
auto shapeInfo = new shape::ShapeInformation();
shapeInfo->shape = shape;
shapeInfo->shape = const_cast<Nd4jLong*>(shape);
shapeInfo->stride = stride;
shapeInfo->offset = 0;
shapeInfo->rank = rank;
@ -1772,13 +1772,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return shapeInfoBuffer;
}
INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *output) {
INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const *shape, Nd4jLong *output) {
Nd4jLong stride[MAX_RANK];
shape::calcStridesFortran(shape,rank, stride);
shape::ShapeInformation shapeInfo;
shapeInfo.shape = shape;
shapeInfo.shape = const_cast<Nd4jLong*>(shape);
shapeInfo.stride = stride;
shapeInfo.offset = 0;
shapeInfo.rank = rank;
@ -2049,7 +2049,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
shape::doPermuteShapeInfo(out, rearrange);
}
INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong *shapeBuffer, int* rearrange) {
INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange) {
auto len = shape::shapeInfoLength(shape::rank(shapeBuffer));
Nd4jLong *copy = shape::copyOf(len, shapeBuffer);
shape::doPermuteShapeInfo(copy,rearrange);
@ -2238,7 +2238,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
* @param shape the shape of the array
* @param rank the rank of the shape
*/
INLINEDEF _CUDA_HD int isVector(Nd4jLong *shape, int rank) {
INLINEDEF _CUDA_HD int isVector(Nd4jLong const* shape, int rank) {
if (rank == 0)
return 0;
@ -2254,7 +2254,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return 0;
}
INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong *shapeInfo, int& posOfNonUnityDim) {
INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim) {
int numOfNonUnity = 0;
for(int i = 1; i <= shapeInfo[0]; ++i) {
@ -2284,7 +2284,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return numOfNonUnity == 1;
}
INLINEDEF _CUDA_H Nd4jLong* detachShape(Nd4jLong *originalShape) {
INLINEDEF _CUDA_H Nd4jLong const* detachShape(Nd4jLong const* originalShape) {
Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)];
memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape));
@ -2292,7 +2292,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
}
INLINEDEF _CUDA_H Nd4jLong* copyShape(Nd4jLong *originalShape) {
INLINEDEF _CUDA_H Nd4jLong* copyShape(Nd4jLong const* originalShape) {
Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)];
memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape));
@ -2309,7 +2309,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return isVector && shapeFirstOne;
}
INLINEDEF _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo) {
INLINEDEF _CUDA_HD bool isColumnVector(const Nd4jLong *shapeInfo) {
bool isVector = shape::isVector(shapeInfo) == 1;
bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1;
return isVector && !shapeFirstOne;
@ -2381,7 +2381,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* that must be freed elsewhere.
*/
template <typename T>
INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T *toCopy) {
INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const* toCopy) {
traceNew(18);
T *ret = new T[length];
@ -2389,7 +2389,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
}
template <typename T>
INLINEDEF _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy, T *ret) {
INLINEDEF _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret) {
memcpy(ret, toCopy, sizeof(T)*length);
return ret;
}
@ -2400,7 +2400,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* that must be freed elsewhere.
*/
template <typename T>
INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T *from, T *to) {
INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to) {
memcpy(to, from, sizeof(T)*length);
}
@ -2409,7 +2409,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* This buffer allocates memory
* that must be freed elsewhere.
*/
INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong *from, Nd4jLong *to, Nd4jLong *indexes) {
INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes) {
for(int i = 0; i < length; i++) {
to[i] = from[indexes[i]];
}
@ -2817,7 +2817,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* item
*/
template <typename T1, typename T2>
INLINEDEF _CUDA_HD void removeIndex(T1* data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *ret) {
INLINEDEF _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *ret) {
int count = 0;
int absLength = dataLength - indexesLength;
@ -2850,7 +2850,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* item
*/
template <typename T1, typename T2>
INLINEDEF _CUDA_HD T1* removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength) {
INLINEDEF _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength) {
auto lengthOfArr = dataLength - indexesLength;
if(lengthOfArr < 0) {
printf("Remove index call created a <= 0 length array. This was likely not intended.");
@ -2862,7 +2862,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
return ret;
}
INLINEDEF _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong *indexes,int indexesLength,int begin,int end) {
INLINEDEF _CUDA_HD Nd4jLong* everyIndexBut(const Nd4jLong *indexes,int indexesLength,int begin,int end) {
int len = end - indexesLength;
traceNew(20);
@ -3086,7 +3086,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @param dataLength
* @return
*/
INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int* index, int indexLength, int dataLength) {
INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength) {
traceNew(23);
@ -3113,7 +3113,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
*/
template <typename T>
INLINEDEF _CUDA_HD T* reverseCopy(T *data, Nd4jLong length) {
INLINEDEF _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length) {
if (length < 1)
return nullptr;
@ -3129,7 +3129,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
}
template <typename T>
INLINEDEF _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong length) {
INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length) {
if (length < 1)
return;
for (Nd4jLong i = 0; i <= length / 2; i++) {
@ -3140,7 +3140,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
}
template <typename T>
INLINEDEF _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong *indexes, Nd4jLong length) {
INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length) {
if (length < 1)
return;
@ -3161,7 +3161,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @return
*/
template <typename T>
INLINEDEF _CUDA_HD T* concat(T* arr1, Nd4jLong arr1Length, T* arr2, Nd4jLong arr2Length) {
INLINEDEF _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length) {
traceNew(25);
@ -3180,7 +3180,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @return
*/
template <typename T>
INLINEDEF _CUDA_HD T *concat(Nd4jLong numArrays, Nd4jLong numTotalElements, T **arr, Nd4jLong *lengths) {
INLINEDEF _CUDA_HD T *concat(Nd4jLong const numArrays, Nd4jLong const numTotalElements, T const **arr, Nd4jLong const *lengths) {
T* ret = new T[numTotalElements];
Nd4jLong count = 0;
@ -3206,7 +3206,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @return the length per slice of the given shape
* along the given dimension
*/
INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong *shape, int* dimension, int dimensionLength) {
INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength) {
if(shape::isVector(shape,rank)) {
//return total length for row vectors
if(dimensionLength == 1 && shape[0] == 1) {
@ -3230,7 +3230,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @param tensorShape
* @return
*/
INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong *shape, Nd4jLong *tensorShape, int tensorShapeLength, int* dimension, int dimensionLength) {
INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong const* shape, Nd4jLong const* tensorShape, int tensorShapeLength, int const* dimension, int dimensionLength) {
auto tensorLength = prodLong(tensorShape, tensorShapeLength);
auto lengthPerSlice2 = lengthPerSlice(rank, shape, dimension, dimensionLength);
if (lengthPerSlice2 <= 0) {

View File

@ -47,11 +47,11 @@ public:
*/
static void execIndexReduceScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
/**
*
@ -68,13 +68,13 @@ public:
*/
static void execReduce3Scalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
/**
@ -90,13 +90,13 @@ public:
*/
static void execReduce3(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
/**
*
@ -113,29 +113,29 @@ public:
*/
static void execReduce3(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);
const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets);
static void execReduce3All(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets);
/**
*
@ -150,13 +150,13 @@ public:
*/
static void execIndexReduce(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
/**
*
@ -170,73 +170,76 @@ public:
* @param n
*/
static void execScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams,
bool allowParallelism = true);
static void execScalarBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true);
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams,
bool allowParallelism = true);
static void execScalarInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true);
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams,
bool allowParallelism = true);
static void execScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ);
static void execScalarBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalars, const Nd4jLong *hScalarShapeInfo,
const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static void execScalarInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalars, const Nd4jLong *hScalarShapeInfo,
const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ);
/**
@ -252,105 +255,107 @@ static void execScalarInt(sd::LaunchContext *lc,
* @param dimensionLength
*/
static void execBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ);
static void execBroadcast(sd::LaunchContext* lc,
const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execInverseBroadcast(sd::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
const void *x, const Nd4jLong *xShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static void execBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ);
static void execBroadcastBool(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
static void execBroadcastBool(sd::LaunchContext* lc,
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams);
static void execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static void execBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ);
static void execBroadcastInt(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execBroadcastInt(sd::LaunchContext* lc,
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
/**
*
@ -365,34 +370,34 @@ static void execScalarInt(sd::LaunchContext *lc,
* @param n
*/
static void execPairwiseTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams);
static void execPairwiseBoolTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams);
static void execPairwiseIntTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams);
/**
*
@ -405,49 +410,50 @@ static void execScalarInt(sd::LaunchContext *lc,
* @param n
*/
static void execTransformFloat(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execTransformAny(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism = true);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool allowParallelism = true);
static void execTransformStrict(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execTransformSame(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execTransformBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
/**
*
* @param opNum
@ -458,44 +464,44 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param resultShapeInfo
*/
static void execReduceFloat(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execReduceSame(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execReduceBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execReduceLong(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
/**
*
@ -506,49 +512,49 @@ static void execTransformBool(sd::LaunchContext *lc,
* @return
*/
static void execReduceFloatScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduceBoolScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduceSameScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduceLongScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduce3TAD(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets);
/**
*
@ -562,15 +568,15 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param dimensionLength
*/
static void execSummaryStats(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool biasCorrected);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool biasCorrected);
/**
*
@ -582,13 +588,13 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param resultShapeInfo
*/
static void execSummaryStats(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected);
/**
*
@ -600,68 +606,51 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param resultShapeInfo
*/
static void execSummaryStatsScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected);
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected);
static void execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer state,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraArguments);
int opNum,
Nd4jPointer state,
void *hZ, const Nd4jLong *hZShapeBuffer,
void *dZ, const Nd4jLong *dZShapeBuffer,
void *extraArguments);
static void execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraArguments);
int opNum,
Nd4jPointer state,
const void *hX, const Nd4jLong *hXShapeBuffer,
const void *dX, const Nd4jLong *dXShapeBuffer,
void *hZ, const Nd4jLong *hZShapeBuffer,
void *dZ, const Nd4jLong *dZShapeBuffer,
void *extraArguments);
static void execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer,
void *hY, Nd4jLong *hYShapeBuffer,
void *dY, Nd4jLong *dYShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraArguments);
int opNum,
Nd4jPointer state,
const void *hX, const Nd4jLong *hXShapeBuffer,
const void *dX, const Nd4jLong *dXShapeBuffer,
const void *hY, const Nd4jLong *hYShapeBuffer,
const void *dY, const Nd4jLong *dYShapeBuffer,
void *hZ, const Nd4jLong *hZShapeBuffer,
void *dZ, const Nd4jLong *dZShapeBuffer,
void *extraArguments);
template <typename X>
static FORCEINLINE void execAggregate(sd::LaunchContext *lc,
int opNum,
void **varguments,
int numArguments,
Nd4jLong **shapeArguments,
int numShapeArguments,
int *indexArguments,
int numIndexArguments,
int **intArrays,
int numIntArrays,
void *vrealArguments,
int numRealArguments) {
}
inline static void execSort(void *x, Nd4jLong *xShapeInfo, bool descending) {
inline static void execSort(void *x, const Nd4jLong *xShapeInfo, bool descending) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortGeneric(x, xShapeInfo, descending), LIBND4J_TYPES);
}
static void execSort(void *x, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending) {
static void execSort(void *x, const Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, bool descending) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
@ -672,13 +661,13 @@ static void execTransformBool(sd::LaunchContext *lc,
}
inline static Nd4jLong encodeBitmap(void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, return sd::SpecialMethods, ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES);
}
inline static void decodeBitmap(void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo) {
inline static void decodeBitmap(const void *dx, Nd4jLong N, void *dz, const Nd4jLong *zShapeInfo) {
auto zType = sd::ArrayOptions::dataType(zShapeInfo);
BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), FLOAT_TYPES);

View File

@ -122,9 +122,9 @@ ND4J_EXPORT void setTADThreshold(int num);
*/
ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/**
*
@ -139,10 +139,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
/**
*
@ -159,20 +159,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
ND4J_EXPORT void execBroadcast(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execBroadcastBool(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
/**
*
@ -189,17 +189,17 @@ ND4J_EXPORT void execBroadcastBool(
ND4J_EXPORT void execPairwiseTransform(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execPairwiseTransformBool(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
/**
@ -213,28 +213,28 @@ ND4J_EXPORT void execPairwiseTransformBool(
*/
ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/**
*
@ -247,34 +247,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
/**
*
@ -289,10 +289,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/**
*
@ -305,10 +305,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/**
*
* @param opNum
@ -324,24 +324,24 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets);
ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets,
Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets);
/**
*
@ -356,16 +356,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo,
void *extraParams);
ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo,
void *extraParams);
/**
@ -377,9 +377,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected);
/**
*
@ -392,9 +392,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected);
/**
*
@ -409,12 +409,12 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets);
/**
*
@ -428,32 +428,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams);
/**
@ -471,23 +471,23 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ);
ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ);
ND4J_EXPORT void specialConcat (
Nd4jPointer *extraPointers,
@ -496,7 +496,7 @@ ND4J_EXPORT void specialConcat (
Nd4jPointer *data,
Nd4jPointer *inputShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
Nd4jLong const* resultShapeInfo,
Nd4jPointer *tadPointers,
Nd4jPointer *offsetPointers);
@ -792,14 +792,14 @@ typedef sd::TadPack OpaqueTadPack;
* @param targetBuffer
* @param offsetsBuffer
*/
ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo,
ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong const*xShapeInfo,
int *dimension,
int dimensionLength);
ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong const* getPrimaryShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong const* getPrimaryOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong const* getSpecialShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong const* getSpecialOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack);
ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack);
@ -824,14 +824,14 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
* @param zTadOffsets
*/
ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dzShapeInfo,
Nd4jLong n,
Nd4jLong *indexes,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets);
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets,
Nd4jLong const* zTadShapeInfo,
Nd4jLong const* zTadOffsets);
/**
*
@ -843,20 +843,20 @@ ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
* @param propagate
*/
ND4J_EXPORT void average(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dxShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo,
Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong const* dxShapeInfo,
void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong const* dzShapeInfo,
int n,
Nd4jLong length,
bool propagate);
ND4J_EXPORT void accumulate(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dxShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo,
Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong const* dxShapeInfo,
void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong const* dzShapeInfo,
int n,
Nd4jLong length);
@ -1004,7 +1004,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer,
void *extraArguments);
/**
@ -1023,9 +1023,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeBuffer, Nd4jLong const* dYShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer,
void *extraArguments);
/**
@ -1042,8 +1042,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer,
void *extraArguments);
@ -1098,11 +1098,11 @@ ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom);
*/
template <typename T>
static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) {
Nd4jLong *shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,const Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) {
Nd4jLong const* shapeBufferCast = reinterpret_cast<const Nd4jLong *>(shapeBuffer);
int rank = shape::rank(shapeBufferCast);
Nd4jLong *shape = shape::shapeOf(shapeBufferCast);
unsigned int *npShape = new unsigned int[rank];
const Nd4jLong* shape = shape::shapeOf(shapeBufferCast);
unsigned int* npShape = new unsigned int[rank];
for(int i = 0; i < rank; i++) {
npShape[i] = shape[i];
}
@ -1125,7 +1125,7 @@ static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,
extern "C" {
static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) {
static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) {
auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
auto type = sd::ArrayOptions::dataType(shapeBufferCast);
BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES);
@ -1427,53 +1427,53 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address);
* @return
*/
ND4J_EXPORT void tear(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
Nd4jPointer *targets, Nd4jLong *zShapeInfo,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets);
OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo,
Nd4jPointer *targets, Nd4jLong const* zShapeInfo,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets);
ND4J_EXPORT void sort(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong const* dxShapeInfo,
bool descending);
ND4J_EXPORT void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
bool descending);
ND4J_EXPORT void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
bool descending);
ND4J_EXPORT void sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong const* dxShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets,
bool descending);
ND4J_EXPORT void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending);
ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending);
@ -1509,7 +1509,7 @@ ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, N
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
ND4J_EXPORT Nd4jLong const* getShape(OpaqueShapeList* list, Nd4jLong i);
ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
@ -1526,7 +1526,7 @@ ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i);
ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable);
ND4J_EXPORT Nd4jLong const* getVariableShape(OpaqueVariable* variable);
ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable);
ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId);
@ -1545,7 +1545,7 @@ ND4J_EXPORT void deleteGraphState(Nd4jPointer state);
ND4J_EXPORT void deleteResultWrapper(Nd4jPointer ptr);
ND4J_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong *xShapeInfo, int N, float threshold);
ND4J_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong const* xShapeInfo, int N, float threshold);
// this method executes op that requires scope to be present: if/while/cond/whatever
ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs);
@ -1557,11 +1557,11 @@ ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer pt
ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo);
void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets,
void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets,
void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets,
void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets,
void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo);
ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
@ -1570,7 +1570,7 @@ typedef sd::ConstantDataBuffer OpaqueConstantDataBuffer;
ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor);

View File

@ -77,11 +77,11 @@
* @param hZShapeInfo
*/
void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
@ -106,22 +106,21 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNu
*/
void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
Nd4jLong* hz = reinterpret_cast<Nd4jLong*>(hZ);
auto hz = reinterpret_cast<Nd4jLong*>(hZ);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES);
// BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
}
////////////////////////////////////////////////////////////////////////
@ -139,16 +138,16 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
*/
void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
@ -230,15 +229,15 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext* lc, const int opNum,
void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -269,17 +268,17 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
@ -320,17 +319,17 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opN
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -358,16 +357,16 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -422,16 +421,16 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, const int opN
}
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -471,14 +470,14 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
* @param n
*/
void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -504,14 +503,14 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -538,14 +537,14 @@ void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -580,14 +579,14 @@ void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
* @param hZShapeInfo
*/
void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
@ -609,14 +608,14 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -637,14 +636,14 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -665,14 +664,14 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -701,12 +700,12 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
* @return
*/
void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -717,12 +716,12 @@ void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -732,14 +731,12 @@ void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -749,13 +746,12 @@ void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -779,14 +775,15 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
* @param dimensionLength
*/
void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -807,15 +804,14 @@ void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
* @param hZShapeInfo
*/
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -826,17 +822,17 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -867,18 +863,17 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -895,19 +890,17 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -948,14 +941,15 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
* @param n
*/
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo,
void *extraParams, bool allowParallelism) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalar, const Nd4jLong *hScalarShapeInfo,
const void *dScalar, const Nd4jLong *dScalarShapeInfo,
void *extraParams,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -983,16 +977,16 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const*hXShapeInfo,
void const* dX, Nd4jLong const*dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *hZ, Nd4jLong const*hZShapeInfo,
void *dZ, Nd4jLong const*dZShapeInfo,
void const* hScalars, Nd4jLong const*hScalarShapeInfo,
void const* dScalars, Nd4jLong const*dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const*tadShapeInfo, Nd4jLong const*tadOffsets,
Nd4jLong const*tadShapeInfoZ, Nd4jLong const*tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -1019,14 +1013,15 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo);
@ -1052,17 +1047,17 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalars, const Nd4jLong *hScalarShapeInfo,
const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -1087,14 +1082,15 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo);
@ -1120,17 +1116,17 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
const void *hScalars, const Nd4jLong *hScalarShapeInfo,
const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -1164,13 +1160,13 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
* @param hZShapeInfo
*/
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1190,13 +1186,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
* @param hZShapeInfo
*/
void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1218,15 +1214,15 @@ void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
* @param dimensionLength
*/
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool biasCorrected) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool biasCorrected) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1246,13 +1242,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
* @param n
*/
void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1268,13 +1264,13 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1290,13 +1286,14 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1319,13 +1316,13 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1341,13 +1338,13 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1363,11 +1360,11 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer state,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraArguments) {
int opNum,
Nd4jPointer state,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraArguments) {
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1380,14 +1377,13 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer state,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraArguments) {
int opNum,
Nd4jPointer state,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraArguments) {
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1399,16 +1395,15 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer state,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraArguments) {
int opNum,
Nd4jPointer state,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraArguments) {
auto xType = sd::ArrayOptions::dataType(hZShapeInfo);

View File

@ -102,9 +102,9 @@ void setTADThreshold(int num) {
*/
void execIndexReduceScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
} catch (std::exception &e) {
@ -125,10 +125,10 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers,
* @param dimensionLength
*/
void execIndexReduce(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -176,18 +176,16 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum,
*/
void execBroadcast(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
dimensionLength);
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension,
dimensionLength);
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength);
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
auto hTADOffsets = tadPackX.primaryOffsets();
@ -216,19 +214,17 @@ void execBroadcast(Nd4jPointer *extraPointers,
void execBroadcastBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
dimensionLength);
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension,
dimensionLength);
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength);
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
auto hTADOffsets = tadPackX.primaryOffsets();
@ -272,9 +268,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void execPairwiseTransform(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execPairwiseTransform(nullptr,
@ -301,9 +297,9 @@ void execPairwiseTransform(
void execPairwiseTransformBool(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
@ -340,9 +336,9 @@ void execPairwiseTransformBool(
void execReduceFloat(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execReduceFloatScalar(nullptr,
@ -365,9 +361,9 @@ void execReduceFloat(
void execReduceSame(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execReduceSameScalar(nullptr,
@ -390,9 +386,9 @@ void execReduceSame(
void execReduceBool(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execReduceBoolScalar(nullptr,
opNum,
@ -414,9 +410,9 @@ void execReduceBool(
void execReduceLong(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execReduceLongScalar(nullptr,
opNum,
@ -446,16 +442,15 @@ void execReduceLong(
*/
void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
dimensionLength);
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
auto hTADOffsets = tadPackX.primaryOffsets();
@ -482,13 +477,13 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
void execReduceBool2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
dimensionLength);
@ -518,10 +513,10 @@ void execReduceBool2(Nd4jPointer *extraPointers,
void execReduceSame2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -554,16 +549,15 @@ void execReduceSame2(Nd4jPointer *extraPointers,
void execReduceLong2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
dimensionLength);
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
auto hTADShapeInfo = tadPack.primaryShapeInfo();
auto hTADOffsets = tadPack.primaryOffsets();
@ -601,10 +595,10 @@ void execReduceLong2(Nd4jPointer *extraPointers,
*/
void execReduce3(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo,
dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
@ -624,10 +618,10 @@ void execReduce3(Nd4jPointer *extraPointers,
* @param hYShapeInfo
*/
void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try {
NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(),
hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
@ -651,16 +645,16 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
*/
void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
if (extraPointers == nullptr || extraPointers[2] == 0) {
NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo,
@ -704,9 +698,9 @@ bool isBlasVersionMatches(int major, int minor, int build) {
void execScalar(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execScalar(nullptr,
@ -733,9 +727,9 @@ void execScalar(
void execScalarBool(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execScalarBool(nullptr,
@ -768,9 +762,9 @@ void execScalarBool(
*/
void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
bool biasCorrected) {
try {
NativeOpExecutioner::execSummaryStatsScalar(nullptr,
@ -801,9 +795,9 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers,
*/
void execSummaryStats(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
bool biasCorrected) {
try {
NativeOpExecutioner::execSummaryStats(nullptr,
@ -836,12 +830,12 @@ void execSummaryStats(Nd4jPointer *extraPointers,
*/
void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -882,8 +876,8 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers,
void execTransformFloat(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execTransformFloat(nullptr,
@ -908,8 +902,8 @@ void execTransformFloat(
void execTransformSame(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execTransformSame(nullptr,
@ -934,8 +928,8 @@ void execTransformSame(
void execTransformBool(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execTransformBool(nullptr,
@ -960,8 +954,8 @@ void execTransformBool(
void execTransformAny(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execTransformAny(nullptr,
@ -986,8 +980,8 @@ void execTransformAny(
void execTransformStrict(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) {
try {
NativeOpExecutioner::execTransformStrict(nullptr,
@ -1011,19 +1005,17 @@ void execTransformStrict(
void execReduce3All(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yOffsets) {
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(),
@ -1046,7 +1038,7 @@ void specialConcat(
Nd4jPointer *data,
Nd4jPointer *inputShapeInfo,
void *hZ,
Nd4jLong *hZShapeInfo,
Nd4jLong const* hZShapeInfo,
Nd4jPointer *tadPointers,
Nd4jPointer *offsetPointers) {
try {
@ -1227,7 +1219,7 @@ void setGridLimit(int gridSize) {
// no-op
}
sd::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimensionLength) {
sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* hXShapeInfo, int *dimension, int dimensionLength) {
auto pack = new TadPack();
try {
*pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
@ -1239,21 +1231,26 @@ sd::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimensi
return pack;
}
Nd4jLong* getPrimaryShapeInfo(sd::TadPack* pack) {
return pack->primaryShapeInfo();
Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->primaryShapeInfo());
}
Nd4jLong* getPrimaryOffsets(sd::TadPack* pack) {
return pack->primaryOffsets();
Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->primaryOffsets());
}
Nd4jLong* getSpecialShapeInfo(sd::TadPack* pack) {
return pack->specialShapeInfo();
Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->specialShapeInfo());
}
Nd4jLong* getSpecialOffsets(sd::TadPack* pack) {
return pack->specialOffsets();
Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->specialOffsets());
}
Nd4jLong getNumberOfTads(sd::TadPack* pack) {
return pack->numberOfTads();
}
int getShapeInfoLength(sd::TadPack* pack) {
return pack->shapeInfoLength();
}
@ -1270,15 +1267,15 @@ Nd4jPointer getConstantSpace() {
template<typename T>
void pullRowsGeneric(void *vx,
Nd4jLong *hXShapeInfo,
Nd4jLong const* hXShapeInfo,
void *vz,
Nd4jLong *hZShapeInfo,
Nd4jLong const* hZShapeInfo,
const int n,
Nd4jLong *indexes,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets) {
Nd4jLong const* indexes,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets,
Nd4jLong const* zTadShapeInfo,
Nd4jLong const* zTadOffsets) {
auto hX = reinterpret_cast<T *>(vx);
auto hZ = reinterpret_cast<T *>(vz);
@ -1322,14 +1319,14 @@ void pullRowsGeneric(void *vx,
}
void pullRows(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
Nd4jLong n,
Nd4jLong *indexes,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets) {
Nd4jLong* indexes,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets,
Nd4jLong const* zTadShapeInfo,
Nd4jLong const* zTadOffsets) {
try {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1342,11 +1339,11 @@ void pullRows(Nd4jPointer *extraPointers,
template<typename T>
void tearGeneric(void *vx,
Nd4jLong *hXShapeInfo,
Nd4jLong const* hXShapeInfo,
Nd4jPointer *targets,
Nd4jLong *hZShapeInfo,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets) {
Nd4jLong const* hZShapeInfo,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets) {
auto hX = reinterpret_cast<T *>(vx);
@ -1381,11 +1378,11 @@ void tearGeneric(void *vx,
}
void tear(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
Nd4jPointer *targets,
Nd4jLong *hZShapeInfo,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets) {
Nd4jLong const* hZShapeInfo,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets) {
try {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1398,10 +1395,10 @@ void tear(Nd4jPointer *extraPointers,
void average(Nd4jPointer *extras,
Nd4jPointer *hX, Nd4jLong *hXShapeInfo,
Nd4jPointer *dX, Nd4jLong *dXShapeInfo,
void *z, Nd4jLong *hZShapeInfo,
void *dz, Nd4jLong *dZShapeInfo,
Nd4jPointer *hX, const Nd4jLong *hXShapeInfo,
Nd4jPointer *dX, const Nd4jLong *dXShapeInfo,
void *z, const Nd4jLong *hZShapeInfo,
void *dz, const Nd4jLong *dZShapeInfo,
int n,
Nd4jLong length,
bool propagate) {
@ -1416,10 +1413,10 @@ void average(Nd4jPointer *extras,
}
void accumulate(Nd4jPointer *extras,
Nd4jPointer *hX, Nd4jLong *hXShapeInfo,
Nd4jPointer *dX, Nd4jLong *dXShapeInfo,
void *hz, Nd4jLong *hZShapeInfo,
void *dz, Nd4jLong *dZShapeInfo,
Nd4jPointer *hX, Nd4jLong const* hXShapeInfo,
Nd4jPointer *dX, Nd4jLong const* dXShapeInfo,
void *hz, Nd4jLong const* hZShapeInfo,
void *dz, Nd4jLong const* dZShapeInfo,
int n,
Nd4jLong length) {
try {
@ -1436,6 +1433,28 @@ void enableP2P(bool enable) {
// no-op
}
void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
// TODO: to be implemented
}
void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, int *dz) {
// TODO: to be implemented
}
void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, int *offsets, Nd4jLong N, int *dz){
// offsets won't be used here
// TODO: to be implemented
}
void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, const Nd4jLong *hZShapeInfo){
// TODO: to be implemented
}
bool isP2PAvailable() {
// always TRUE for cpu backend
return true;
@ -1445,8 +1464,12 @@ void checkP2P() {
// no-op
}
void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, Nd4jLong const* hZShapeInfo) {
NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo);
}
template<typename T>
void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZShapeInfo, int N, int *shuffleMap, Nd4jLong **tadOnlyShapeInfo, Nd4jLong **tadOffsets) {
void shuffleGeneric(void **hX, Nd4jLong * const*hXShapeInfo, void **dz, Nd4jLong * const* hZShapeInfo, int N, int *shuffleMap, Nd4jLong * const* tadOnlyShapeInfo, Nd4jLong * const* tadOffsets) {
auto dX = reinterpret_cast<T **>(hX);
auto dZ = reinterpret_cast<T **>(dz);
@ -1517,10 +1540,10 @@ void shuffle(Nd4jPointer *extras,
Nd4jPointer *tadShapeInfo,
Nd4jPointer *tadOffsets) {
try {
auto xShape = reinterpret_cast<Nd4jLong **>(hXShapeInfo);
auto zShape = reinterpret_cast<Nd4jLong **>(hZShapeInfo);
auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong **>(tadShapeInfo);
auto tadOffset = reinterpret_cast<Nd4jLong **>(tadOffsets);
auto xShape = reinterpret_cast<Nd4jLong * const*>(hXShapeInfo);
auto zShape = reinterpret_cast<Nd4jLong * const*>(hZShapeInfo);
auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong * const*>(tadShapeInfo);
auto tadOffset = reinterpret_cast<Nd4jLong * const*>(tadOffsets);
auto xType = sd::ArrayOptions::dataType(xShape[0]);
@ -1548,13 +1571,13 @@ int getDevice() {
void execScalarTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const*dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const*tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const*tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -1588,13 +1611,13 @@ void execScalarTad(Nd4jPointer *extraPointers,
void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -1696,7 +1719,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
void execRandom(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraArguments) {
try {
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
@ -1709,9 +1732,9 @@ void execRandom(Nd4jPointer *extraPointers,
void execRandom3(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraArguments) {
try {
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
@ -1724,8 +1747,8 @@ void execRandom3(Nd4jPointer *extraPointers,
void execRandom2(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraArguments) {
try {
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
@ -1793,8 +1816,8 @@ Nd4jPointer pointerForAddress(Nd4jLong address) {
}
void sort(Nd4jPointer *extraPointers,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hX, const Nd4jLong *hXShapeInfo,
void *dX, const Nd4jLong *dXShapeInfo,
bool descending) {
try {
NativeOpExecutioner::execSort(hX, hXShapeInfo, descending);
@ -1805,12 +1828,11 @@ void sort(Nd4jPointer *extraPointers,
}
void sortTad(Nd4jPointer *extraPointers,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
void *hX, const Nd4jLong *hXShapeInfo,
void *dX, const Nd4jLong *dXShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo,
const Nd4jLong *tadOffsets,
bool descending) {
try {
NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
@ -1833,6 +1855,12 @@ void sortCooIndices(Nd4jPointer *extraPointers,
}
}
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
}
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
auto hZ = new Nd4jLong[2];errno = 0;
try {
@ -1916,7 +1944,7 @@ FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer
}
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXShapeInfo, int N, float threshold) {
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong const* hXShapeInfo, int N, float threshold) {
try {
auto xType = ArrayOptions::dataType(hXShapeInfo);
BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
@ -1931,8 +1959,8 @@ Nd4jLong getShapeListSize(sd::ShapeList* list) {
return list->size();
}
Nd4jLong* getShape(sd::ShapeList* list, Nd4jLong i) {
return list->at(i);
Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) {
return const_cast<Nd4jLong const*>(list->at(i));
}
void deleteShapeList(Nd4jPointer shapeList) {
@ -2226,8 +2254,8 @@ const char* getVariableName(sd::graph::Variable* variable) {
return variable->getName()->c_str();
}
Nd4jLong* getVariableShape(sd::graph::Variable* variable) {
return variable->getNDArray()->shapeInfo();
Nd4jLong const* getVariableShape(sd::graph::Variable* variable) {
return const_cast<Nd4jLong const*>(variable->getNDArray()->shapeInfo());
}
void* getVariableBuffer(sd::graph::Variable* variable) {
@ -2569,12 +2597,13 @@ void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
}
template <typename I>
static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) {
static void _scatterUpdate(
Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets,
void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets,
void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets,
void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets,
void* vIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) {
auto hIindexes = reinterpret_cast<I*>(vIindexes);
auto func = PRAGMA_THREADS_DO {
@ -2626,11 +2655,11 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
////////////////////////////////////////////////////////////////////////
void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) {
void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets,
void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets,
void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets,
void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets,
void* hIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) {
auto iType = ArrayOptions::dataType(hIndicesShapeInfo);
try {
@ -2686,7 +2715,7 @@ void deleteTadPack(sd::TadPack* ptr) {
delete ptr;
}
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length) {
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, const Nd4jLong *data, int length) {
return nullptr;
}
@ -2847,7 +2876,7 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
} else {
shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
}
return reinterpret_cast<Nd4jPointer>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
return const_cast<Nd4jLong*>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@ -2856,10 +2885,10 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
}
void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, const Nd4jLong *xShapeInfo,
void *dx, const Nd4jLong *dxShapeInfo,
void *y, const Nd4jLong *yShapeInfo,
void *dy, const Nd4jLong *dyShapeInfo,
bool descending) {
try {
auto xType = ArrayOptions::dataType(xShapeInfo);
@ -2873,10 +2902,10 @@ void sortByKey(Nd4jPointer *extraPointers,
}
void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, const Nd4jLong *xShapeInfo,
void *dx, const Nd4jLong *dxShapeInfo,
void *y, const Nd4jLong *yShapeInfo,
void *dy, const Nd4jLong *dyShapeInfo,
bool descending) {
try {
auto xType = ArrayOptions::dataType(xShapeInfo);
@ -2890,12 +2919,11 @@ void sortByValue(Nd4jPointer *extraPointers,
}
void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
void *x, const Nd4jLong *xShapeInfo,
void *dx, const Nd4jLong *dxShapeInfo,
void *y, const Nd4jLong *yShapeInfo,
void *dy, const Nd4jLong *dyShapeInfo,
int *dimension, int dimensionLength,
bool descending) {
try {
auto xType = ArrayOptions::dataType(xShapeInfo);
@ -2909,12 +2937,11 @@ void sortTadByKey(Nd4jPointer *extraPointers,
}
void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
void *x, const Nd4jLong *xShapeInfo,
void *dx, const Nd4jLong *dxShapeInfo,
void *y, const Nd4jLong *yShapeInfo,
void *dy, const Nd4jLong *dyShapeInfo,
int *dimension, int dimensionLength,
bool descending) {
try {
auto xType = ArrayOptions::dataType(xShapeInfo);
@ -3195,8 +3222,8 @@ void dbClose(OpaqueDataBuffer *dataBuffer) {
dataBuffer->getDataBuffer()->close();
}
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong const*, void*, Nd4jLong const*, const int, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong const* , Nd4jPointer*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong* const*, void**, Nd4jLong* const*, int, int*, Nd4jLong* const*, Nd4jLong* const*), LIBND4J_TYPES);

View File

@ -87,12 +87,12 @@ extern "C" __global__ void prepareShapeBuffer(int *dimension, int *maxDimension,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams) {
auto stream = lc->getCudaStream();
@ -128,12 +128,12 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams) {
auto stream = lc->getCudaStream();
@ -164,12 +164,12 @@ void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void * hZ, Nd4jLong const* hZShapeInfo,
void * dZ, Nd4jLong const* dZShapeInfo,
void *extraParams) {
auto stream = lc->getCudaStream();
@ -200,11 +200,11 @@ void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
bool biasCorrected) {
auto stream = lc->getCudaStream();
@ -226,16 +226,16 @@ void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -300,16 +300,16 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opN
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void* hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -338,15 +338,15 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -413,15 +413,15 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext* lc, const int opNu
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -465,15 +465,15 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
*/
void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -536,15 +536,15 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, const int opNum,
void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -572,13 +572,13 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -607,13 +607,13 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension,int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -643,13 +643,13 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -689,13 +689,13 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
*/
void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -734,13 +734,13 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
*/
void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension,int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -774,11 +774,11 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo){
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo){
if (sd::Environment::getInstance()->isDebug())
printf("F1 opNum:[%i]\n", opNum);
@ -825,11 +825,11 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -854,11 +854,11 @@ void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -885,11 +885,11 @@ void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -916,11 +916,11 @@ void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -947,12 +947,12 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
@ -981,12 +981,12 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
@ -1015,12 +1015,12 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool allowParallelism) {
auto stream = lc->getCudaStream();
@ -1050,12 +1050,12 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
@ -1084,12 +1084,12 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -1118,11 +1118,11 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
bool biasCorrected) {
auto stream = lc->getCudaStream();
@ -1147,13 +1147,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
bool biasCorrected) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -1178,13 +1178,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@ -1215,16 +1215,16 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong* tadOnlyShapeInfo, Nd4jLong* tadOffsets,
Nd4jLong* yTadOnlyShapeInfo, Nd4jLong* yTadOffsets) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) {
if(shape::isScalar(hZShapeInfo)) {
NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo);
@ -1268,13 +1268,13 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) {
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream();
@ -1308,12 +1308,12 @@ void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void const* hScalar, Nd4jLong const* hScalarShapeInfo,
void const* dScalar, Nd4jLong const* dScalarShapeInfo,
void *extraParams, bool allowParallelism) {
auto stream = lc->getCudaStream();
@ -1344,16 +1344,16 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -1383,12 +1383,12 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void const* hScalar, Nd4jLong const* hScalarShapeInfo,
void const* dScalar, Nd4jLong const* dScalarShapeInfo,
void *extraParams, bool allowParallelism) {
auto stream = lc->getCudaStream();
@ -1419,16 +1419,16 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -1458,12 +1458,12 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void* hZ, Nd4jLong const* hZShapeInfo,
void* dZ, Nd4jLong const* dZShapeInfo,
void const* hScalar, Nd4jLong const* hScalarShapeInfo,
void const* dScalar, Nd4jLong const* dScalarShapeInfo,
void *extraParams, bool allowParallelism) {
auto stream = lc->getCudaStream();
@ -1493,16 +1493,16 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream();
@ -1531,8 +1531,8 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer stateHost,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraArguments) {
auto stream = lc->getCudaStream();
@ -1564,10 +1564,10 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer stateHost,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraArguments) {
auto stream = lc->getCudaStream();
@ -1599,12 +1599,12 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum,
Nd4jPointer stateHost,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
void *extraArguments) {
auto stream = lc->getCudaStream();
@ -1634,16 +1634,16 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets,
Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) {
auto stream = lc->getCudaStream();
auto allocationPointer = lc->getAllocationPointer();
@ -1676,16 +1676,16 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void const* hX, Nd4jLong const* hXShapeInfo,
void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void const* hY, Nd4jLong const* hYShapeInfo,
void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* yTadShapeInfo, Nd4jLong const* yTadOffsets) {
if(shape::isScalar(hZShapeInfo)) {
NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo);

View File

@ -123,8 +123,8 @@ int getDeviceSharedThreshold(int deviceId) {
sd::buffer::Buffer<Nd4jLong> * createScalarBuffer(cudaStream_t stream) {
Nd4jLong *scalarShapeInfo = shape::createScalarShapeInfo();
sd::buffer::Buffer<Nd4jLong> *buff = sd::buffer::createBuffer(scalarShapeInfo,shape::shapeInfoLength(2), stream);
auto scalarShapeInfo = shape::createScalarShapeInfo();
auto buff = sd::buffer::createBuffer(scalarShapeInfo,shape::shapeInfoLength(2), stream);
sd::buffer::copyDataToGpu(&buff, stream);
return buff;
}
@ -229,9 +229,9 @@ public:
void execPairwiseTransform( Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -251,9 +251,9 @@ void execPairwiseTransform( Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execPairwiseTransformBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -275,9 +275,9 @@ void execPairwiseTransformBool(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -299,11 +299,11 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execBroadcastBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -348,10 +348,10 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void execBroadcast(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -399,9 +399,9 @@ void execBroadcast(
////////////////////////////////////////////////////////////////////////
void execReduceFloat(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -421,9 +421,9 @@ void execReduceFloat(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceSame(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -443,10 +443,10 @@ void execReduceSame(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceSame2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -476,10 +476,10 @@ void execReduceSame2(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceLong2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -509,9 +509,9 @@ void execReduceLong2(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceLong(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -551,10 +551,10 @@ void execReduceLong(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceBool2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -584,9 +584,9 @@ void execReduceBool2(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -637,10 +637,10 @@ void execReduceBool(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execIndexReduce(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -679,10 +679,10 @@ void execIndexReduce(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) {
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -720,9 +720,9 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
void execIndexReduceScalar(
Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo){
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo){
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -741,8 +741,8 @@ void execIndexReduceScalar(
////////////////////////////////////////////////////////////////////////
void execTransformSame(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -766,8 +766,8 @@ void execTransformSame(Nd4jPointer *extraPointers,int opNum,
////////////////////////////////////////////////////////////////////////
void execTransformBool(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -791,8 +791,8 @@ void execTransformBool(Nd4jPointer *extraPointers,int opNum,
////////////////////////////////////////////////////////////////////////
void execTransformAny(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -817,8 +817,8 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
////////////////////////////////////////////////////////////////////////
void execTransformStrict(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -842,8 +842,8 @@ void execTransformStrict(Nd4jPointer *extraPointers,int opNum,
////////////////////////////////////////////////////////////////////////
void execTransformFloat(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1368,7 +1368,7 @@ void specialConcat(
Nd4jPointer *data,
Nd4jPointer *inputShapeInfo,
void *dZ,
Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {
Nd4jLong const* dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {
try {
BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), sd::SpecialMethods,
::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo),
@ -1383,7 +1383,7 @@ void specialConcat(
/**
* This method saves
*/
sd::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) {
sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* dXShapeInfo, int *dimension, int dimensionLength) {
try {
auto pack = new TadPack();
*pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
@ -1395,16 +1395,16 @@ sd::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensi
}
}
Nd4jLong* getPrimaryShapeInfo(sd::TadPack* pack) {
Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) {
return pack->primaryShapeInfo();
}
Nd4jLong* getPrimaryOffsets(sd::TadPack* pack) {
Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) {
return pack->primaryOffsets();
}
Nd4jLong* getSpecialShapeInfo(sd::TadPack* pack) {
Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) {
return pack->specialShapeInfo();
}
Nd4jLong* getSpecialOffsets(sd::TadPack* pack) {
Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) {
return pack->specialOffsets();
}
Nd4jLong getNumberOfTads(sd::TadPack* pack) {
@ -1460,14 +1460,14 @@ Nd4jPointer getConstantSpace() {
}
void pullRows(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dZShapeInfo,
Nd4jLong n,
Nd4jLong *indexes,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets) {
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets,
Nd4jLong const* zTadShapeInfo,
Nd4jLong const* zTadOffsets) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1489,10 +1489,10 @@ void pullRows(Nd4jPointer *extraPointers,
void average(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dXShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo,
Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong const* dXShapeInfo,
void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong const* dzShapeInfo,
int n,
Nd4jLong length,
bool propagate) {
@ -1524,10 +1524,10 @@ void average(Nd4jPointer *extras,
}
void accumulate(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dXShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo,
Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong const* dXShapeInfo,
void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong const* dzShapeInfo,
int n,
Nd4jLong length) {
try {
@ -1572,8 +1572,8 @@ void shuffle(Nd4jPointer *extras,
auto dX = reinterpret_cast<void **>(dx);
auto dZ = reinterpret_cast<void **>(dz);
auto xShape = reinterpret_cast<Nd4jLong **>(xShapeInfo);
auto dxShape = reinterpret_cast<Nd4jLong **>(dXShapeInfo);
auto xShape = reinterpret_cast<Nd4jLong**>(xShapeInfo);
auto dxShape = reinterpret_cast<Nd4jLong**>(dXShapeInfo);
auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong **>(tadShapeInfo);
auto tadOffset = reinterpret_cast<Nd4jLong **>(tadOffsets);
@ -1614,9 +1614,9 @@ void setTADThreshold(int num) {
////////////////////////////////////////////////////////////////////////
void execSummaryStats(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1638,12 +1638,12 @@ void execSummaryStats(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1670,10 +1670,10 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduce3(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -1694,13 +1694,13 @@ void execReduce3(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1744,10 +1744,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) {
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -1768,9 +1768,9 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
////////////////////////////////////////////////////////////////////////
void execScalarBool(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar});
@ -1792,13 +1792,13 @@ void execScalarBool(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1825,9 +1825,9 @@ void execScalarBoolTad(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execScalar(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar});
@ -1849,13 +1849,13 @@ void execScalar(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execScalarTad(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1931,7 +1931,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
void execRandom(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer stateHost,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraArguments) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {});
@ -1950,8 +1950,8 @@ void execRandom(Nd4jPointer *extraPointers,
////////////////////////////////////////////////////////////////////////
void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraArguments) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1971,9 +1971,9 @@ void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
////////////////////////////////////////////////////////////////////////
void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraArguments) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -2091,11 +2091,11 @@ Nd4jPointer pointerForAddress(Nd4jLong address) {
}
void tear(Nd4jPointer *extras,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo,
Nd4jPointer *targets,
Nd4jLong *zShapeInfo,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets) {
Nd4jLong const* zShapeInfo,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets) {
try {
InteropDataBuffer::prepareSpecialUse({}, {dbX});
@ -2200,13 +2200,13 @@ void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, int numElement
////////////////////////////////////////////////////////////////////////
void execReduce3All(Nd4jPointer *extraPointers,
int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets,
Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) {
try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -2232,8 +2232,8 @@ void execReduce3All(Nd4jPointer *extraPointers,
void sort(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong const* dXShapeInfo,
bool descending) {
try {
cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -2298,10 +2298,10 @@ void sort(Nd4jPointer *extraPointers,
void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
bool descending) {
try {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -2372,10 +2372,10 @@ void sortByKey(Nd4jPointer *extraPointers,
}
void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
bool descending) {
try {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -2447,10 +2447,10 @@ void sortByValue(Nd4jPointer *extraPointers,
void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
@ -2474,10 +2474,10 @@ void sortTadByKey(Nd4jPointer *extraPointers,
}
void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong const* dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
@ -2503,12 +2503,12 @@ void sortTadByValue(Nd4jPointer *extraPointers,
void sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong const* dXShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffsets,
bool descending) {
try {
// to be implemented
@ -2653,7 +2653,7 @@ Nd4jLong getShapeListSize(sd::ShapeList* list) {
return list->size();
}
Nd4jLong* getShape(sd::ShapeList* list, Nd4jLong i) {
Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) {
return list->at(i);
}
@ -2877,7 +2877,7 @@ const char* getVariableName(sd::graph::Variable* variable) {
return variable->getName()->c_str();
}
Nd4jLong* getVariableShape(sd::graph::Variable* variable) {
Nd4jLong const* getVariableShape(sd::graph::Variable* variable) {
return variable->getNDArray()->shapeInfo();
}
@ -3026,7 +3026,7 @@ void deleteResultWrapper(Nd4jPointer ptr) {
delete p;
}
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong *dXShapeInfo, int N, float threshold) {
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong const* dXShapeInfo, int N, float threshold) {
throw std::runtime_error("estimateThreshold: Not implemented yet");
}
@ -3237,7 +3237,7 @@ void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
///////////////////////////////////////////////////////////////////
template<typename T, typename I>
__global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArrs,
void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets,
void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong *xOffsets,
void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
const void* vindexes) {
@ -3300,7 +3300,7 @@ __global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArr
}
template<typename T, typename I>
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) {
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong const* xShapeInfo, const Nd4jLong* xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) {
scatterUpdateCuda<T, I><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes);
}
@ -3308,11 +3308,11 @@ __host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const
//////////////////////////////////////////////////////////////////////////
void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) {
void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets,
void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets,
void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets,
void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets,
void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo) {
try {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -3409,7 +3409,7 @@ bool isBlasVersionMatches(int major, int minor, int build) {
return result;
}
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length) {
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length) {
return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
}
@ -3555,8 +3555,7 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
} else {
shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
}
return reinterpret_cast<Nd4jPointer>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer,
true));
return (Nd4jPointer)(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); // TO DO: this can lead to unpleasant crash sometimes
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

View File

@ -21,6 +21,7 @@
#define DEV_TESTS_BROADCASTSCALARCONVERTER_H
#include <system/op_boilerplate.h>
#include <system/op_enums.h>
#include <stdexcept>
namespace sd {

View File

@ -56,18 +56,15 @@ namespace functions {
class Broadcast {
public:
#ifdef __CUDACC__
#ifdef __CUDABLAS__
template<typename OpType>
static __device__ void transformCuda(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
@ -75,67 +72,83 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo);
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static __device__ void transformInverseCuda(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else
static void execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
static void exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
sd::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop);
uint64_t start, uint64_t stop);
/**
* CPU execution
@ -149,39 +162,25 @@ namespace functions {
* @param dimensionLength the length of the dimension buffer
*/
template<typename OpType>
static void exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
static void exec(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
sd::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop);
uint64_t start, uint64_t stop);
template<typename OpType>
static void execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
static void exec(const int opNum,
static void exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);

View File

@ -58,16 +58,13 @@ namespace functions {
#ifdef __CUDACC__
template<typename OpType>
static __device__ void transformCuda(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
@ -76,7 +73,7 @@ namespace functions {
void *extraParams);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
@ -85,7 +82,7 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
@ -94,63 +91,61 @@ namespace functions {
void *extraParams);
template<typename OpType>
static __device__ void transformInverseCuda(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else
static void exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
static void exec(const int opNum,
static void exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
static void execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
/**
* CPU execution
@ -164,21 +159,14 @@ namespace functions {
* @param dimensionLength the length of the dimension buffer
*/
template<typename OpType>
static void exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
static void exec(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo,
@ -187,21 +175,14 @@ namespace functions {
void *extraParams);
template<typename OpType>
static void execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
#endif
};
}

View File

@ -58,15 +58,12 @@ namespace functions {
#ifdef __CUDACC__
template<typename OpType>
static __device__ void transformCuda(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
@ -74,7 +71,13 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
@ -82,7 +85,14 @@ namespace functions {
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
@ -90,59 +100,55 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static __device__ void transformInverseCuda(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else
static void exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
static void exec(const int opNum,
static void exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
static void execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
/**
* CPU execution
@ -156,20 +162,13 @@ namespace functions {
* @param dimensionLength the length of the dimension buffer
*/
template<typename OpType>
static void exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo,
@ -177,20 +176,13 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static void execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
uint64_t start, uint64_t stop);
#endif
};
}

View File

@ -34,20 +34,13 @@ namespace broadcast {
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::execInverse(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x,
xShapeInfo,
y,
@ -64,21 +57,14 @@ namespace broadcast {
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
sd::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
sd::LoopKind::Kind loopKind,
uint64_t start, uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xShapeInfo,
y,
@ -96,24 +82,17 @@ namespace broadcast {
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
sd::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
sd::LoopKind::Kind loopKind,
uint64_t start, uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz);
//decompose in to several sub tads after
@ -397,23 +376,16 @@ namespace broadcast {
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::execInverse(void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
void Broadcast<X, Y, Z>::execInverse(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz);
//decompose in to several sub tads after

View File

@ -33,21 +33,14 @@ namespace broadcast {
template <typename X, typename Y>
void BroadcastBool<X, Y>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo,
y,
@ -75,21 +68,14 @@ namespace broadcast {
template <typename X, typename Y>
void BroadcastBool<X, Y>::execInverse(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x,
xShapeInfo,
y,
@ -107,24 +93,17 @@ namespace broadcast {
template <typename X, typename Z>
template<typename OpType>
void BroadcastBool<X, Z>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -138,8 +117,8 @@ namespace broadcast {
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
xTadShapeShapeInfo = tadPack.primaryShapeInfo();
tadOffsets = tadPack.primaryOffsets();
xTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
}
//int *resultStride = shape::stride(xTadShapeShapeInfo);
@ -279,24 +258,17 @@ namespace broadcast {
template <typename X, typename Z>
template<typename OpType>
void BroadcastBool<X, Z>::execInverse(void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension,
int dimensionLength,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
void BroadcastBool<X, Z>::execInverse(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -310,8 +282,8 @@ namespace broadcast {
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
yTadShapeShapeInfo = tadPack.primaryShapeInfo();
tadOffsets = tadPack.primaryOffsets();
yTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
}
//int *resultStride = shape::stride(yTadShapeShapeInfo);

View File

@ -33,20 +33,13 @@ namespace functions {
template <typename X>
void BroadcastInt<X>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo,
y,
@ -72,20 +65,13 @@ namespace functions {
template <typename X>
void BroadcastInt<X>::execInverse(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x,
xShapeInfo,
y,
@ -102,23 +88,16 @@ namespace functions {
template <typename X>
template<typename OpType>
void BroadcastInt<X>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz);
//decompose in to several sub tads after
@ -131,8 +110,8 @@ namespace functions {
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
xTadShapeShapeInfo = tadPack.primaryShapeInfo();
tadOffsets = tadPack.primaryOffsets();
xTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
}
//int *resultStride = shape::stride(xTadShapeShapeInfo);
@ -272,23 +251,16 @@ namespace functions {
template <typename X>
template<typename OpType>
void BroadcastInt<X>::execInverse(void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
void BroadcastInt<X>::execInverse(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, const int dimensionLength,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
uint64_t start, uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz);
//decompose in to several sub tads after
@ -301,8 +273,8 @@ namespace functions {
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
yTadShapeShapeInfo = tadPack.primaryShapeInfo();
tadOffsets = tadPack.primaryOffsets();
yTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
}
//int *resultStride = shape::stride(yTadShapeShapeInfo);

View File

@ -33,27 +33,27 @@ namespace indexreduce {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
Nd4jLong IndexReduce<X,Y>::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
Nd4jLong IndexReduce<X,Y>::execScalar( const int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
void IndexReduce<X,Y>::exec(const int opNum,
void *x, Nd4jLong *xShapeInfo,
void *extraParams,
void *z, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
template<typename OpType>
Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
Nd4jLong IndexReduce<X, Y>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
//T startingVal = OpType::startingValue(x);
@ -107,13 +107,13 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
void IndexReduce<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -136,7 +136,7 @@ void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
}
auto tadOnlyShapeInfo = tadShapeInfo;
Nd4jLong *tadOffsets = tadOffset;
auto tadOffsets = tadOffset;
if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) {
if (dimensionLength < 1)

View File

@ -34,18 +34,13 @@ namespace functions {
namespace pairwise_transforms {
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong xEws,
void *y,
Nd4jLong yEws,
void *z,
Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
void PairWiseTransform<X, Y, Z>::exec(const int opNum,
const void *x, Nd4jLong xEws,
const void *y, Nd4jLong yEws,
void *z, Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xEws,
y,
@ -60,16 +55,16 @@ namespace functions {
template <typename X, typename Y, typename Z>
template <typename OpType>
void PairWiseTransform<X, Y, Z>::exec(void *vx, Nd4jLong xEws,
void *vy, Nd4jLong yEws,
void PairWiseTransform<X, Y, Z>::exec(const void *vx, Nd4jLong xEws,
const void *vy, Nd4jLong yEws,
void *vz, Nd4jLong zEws,
void *vextraParams,
const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -86,17 +81,12 @@ namespace functions {
}
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,
const uint64_t stop) {
void PairWiseTransform<X, Y, Z>::exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xShapeInfo,
y,
@ -110,19 +100,14 @@ namespace functions {
template <typename X, typename Y, typename Z>
template <typename OpType>
void PairWiseTransform<X, Y, Z>::exec(
void *vx,
Nd4jLong* xShapeInfo,
void *vy,
Nd4jLong* yShapeInfo,
void *vz,
Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start,
const uint64_t stop) {
void PairWiseTransform<X, Y, Z>::exec(const void *vx, const Nd4jLong* xShapeInfo,
const void *vy, const Nd4jLong* yShapeInfo,
void *vz, const Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);

View File

@ -30,18 +30,13 @@ namespace functions {
namespace pairwise_transforms {
template <typename X, typename Y>
void PairWiseBoolTransform<X, Y>::exec(
const int opNum,
void *x,
Nd4jLong xEws,
void *y,
Nd4jLong yEws,
void *z,
Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
void PairWiseBoolTransform<X, Y>::exec(const int opNum,
const void *x, Nd4jLong xEws,
const void *y, Nd4jLong yEws,
void *z, Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xEws,
y,
@ -56,19 +51,15 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
void PairWiseBoolTransform<X, Z>::exec(void *vx,
Nd4jLong xEws,
void *vy,
Nd4jLong yEws,
void *vz,
Nd4jLong zEws,
void *vextraParams,
const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
void PairWiseBoolTransform<X, Z>::exec(const void *vx, Nd4jLong xEws,
const void *vy, Nd4jLong yEws,
void *vz, Nd4jLong zEws,
void *vextraParams,
const Nd4jLong n,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -85,17 +76,12 @@ namespace functions {
}
template <typename X, typename Y>
void PairWiseBoolTransform<X, Y>::exec(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,
const uint64_t stop) {
void PairWiseBoolTransform<X, Y>::exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo,
y,
@ -109,15 +95,14 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
void PairWiseBoolTransform<X, Z>::exec(void *vx, Nd4jLong* xShapeInfo,
void *vy, Nd4jLong* yShapeInfo,
void *vz, Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start,
const uint64_t stop) {
void PairWiseBoolTransform<X, Z>::exec(const void *vx, const Nd4jLong* xShapeInfo,
const void *vy, const Nd4jLong* yShapeInfo,
void *vz, const Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -30,18 +30,13 @@ namespace functions {
namespace pairwise_transforms {
template <typename X>
void PairWiseIntTransform<X>::exec(
const int opNum,
void *x,
Nd4jLong xEws,
void *y,
Nd4jLong yEws,
void *z,
Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
void PairWiseIntTransform<X>::exec(const int opNum,
const void *x, Nd4jLong xEws,
const void *y, Nd4jLong yEws,
void *z, Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xEws,
y,
@ -56,19 +51,15 @@ namespace functions {
template <typename X>
template <typename OpType>
void PairWiseIntTransform<X>::exec(void *vx,
Nd4jLong xEws,
void *vy,
Nd4jLong yEws,
void *vz,
Nd4jLong zEws,
void *vextraParams,
const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
void PairWiseIntTransform<X>::exec(const void *vx, Nd4jLong xEws,
const void *vy, Nd4jLong yEws,
void *vz, Nd4jLong zEws,
void *vextraParams,
const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -85,17 +76,12 @@ namespace functions {
}
template <typename X>
void PairWiseIntTransform<X>::exec(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,
const uint64_t stop) {
void PairWiseIntTransform<X>::exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo,
y,
@ -109,15 +95,15 @@ namespace functions {
template <typename X>
template <typename OpType>
void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo,
void *vy, Nd4jLong* yShapeInfo,
void *vz, Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start,
const uint64_t stop) {
void PairWiseIntTransform<X>::exec(const void *vx, const Nd4jLong* xShapeInfo,
const void *vy, const Nd4jLong* yShapeInfo,
void *vz, const Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -33,16 +33,13 @@ namespace functions {
template<typename X>
template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state,
void *vx,
Nd4jLong *xShapeInfo,
void *vy,
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraArguments) {
const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraArguments) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz);
auto extraArguments = reinterpret_cast<X *>(vextraArguments);
@ -166,12 +163,10 @@ namespace functions {
template<typename X>
template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state,
void *vx,
Nd4jLong *xShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraArguments) {
auto x = reinterpret_cast<X *>(vx);
const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraArguments) {
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto extraArguments = reinterpret_cast<X *>(vextraArguments);
@ -227,7 +222,7 @@ namespace functions {
template<typename X>
template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state, void *vz, Nd4jLong *zShapeInfo, void *vextraArguments) {
void RandomFunction<X>::execTransform(Nd4jPointer state, void *vz, const Nd4jLong *zShapeInfo, void *vextraArguments) {
auto z = reinterpret_cast<X *>(vz);
auto extraArguments = reinterpret_cast<X *>(vextraArguments);
@ -266,17 +261,17 @@ namespace functions {
}
template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraArguments) {
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS)
}
template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraArguments) {
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS)
}
template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeInfo, void *extraArguments) {
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
}

View File

@ -33,12 +33,10 @@ namespace functions {
namespace reduce {
template <typename X, typename Z>
template <typename OpType>
void _CUDA_H ReduceBoolFunction<X,Z>::execScalar(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
void _CUDA_H ReduceBoolFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -78,9 +76,9 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo);
@ -103,49 +101,39 @@ namespace functions {
template <typename X, typename Y>
Y ReduceBoolFunction<X, Y>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_BOOL_OPS);
}
template <typename X, typename Y>
void ReduceBoolFunction<X, Y>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS);
}
template <typename X, typename Y>
void ReduceBoolFunction<X, Y>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS);
}
template <typename X, typename Z>
template <typename OpType>
void _CUDA_H ReduceBoolFunction<X,Z>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -193,20 +181,17 @@ namespace functions {
template <typename X, typename Z>
template<typename OpType>
void _CUDA_H ReduceBoolFunction<X,Z>::exec(void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vresult,
Nd4jLong *resultShapeInfo) {
// FIXME: wtf???
void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *vresult, const Nd4jLong *resultShapeInfo) {
auto z = reinterpret_cast<Z*>(vresult);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
}
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];

View File

@ -33,12 +33,10 @@ namespace functions {
namespace reduce {
template <typename X, typename Z>
template <typename OpType>
void _CUDA_H ReduceFloatFunction<X,Z>::execScalar(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
void _CUDA_H ReduceFloatFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -98,8 +96,8 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo);
@ -122,33 +120,27 @@ namespace functions {
template <typename X, typename Y>
Y ReduceFloatFunction<X, Y>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_FLOAT_OPS);
}
template <typename X, typename Y>
void ReduceFloatFunction<X, Y>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS);
}
template <typename X, typename Y>
void ReduceFloatFunction<X, Y>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo,
extraParams,
@ -163,17 +155,14 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
void _CUDA_H ReduceFloatFunction<X,Z>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -226,11 +215,9 @@ namespace functions {
template <typename X, typename Z>
template<typename OpType>
void _CUDA_H ReduceFloatFunction<X,Z>::exec(void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vresult,
Nd4jLong *resultShapeInfo) {
void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *vresult, const Nd4jLong *resultShapeInfo) {
// FIXME: wtf???
auto z = reinterpret_cast<Z*>(vresult);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
@ -238,9 +225,9 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];

View File

@ -33,12 +33,10 @@ namespace functions {
namespace reduce {
template <typename X, typename Z>
template <typename OpType>
void _CUDA_H ReduceLongFunction<X,Z>::execScalar(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
void _CUDA_H ReduceLongFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -93,10 +91,8 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo);
@ -120,49 +116,40 @@ namespace functions {
template <typename X, typename Y>
Y ReduceLongFunction<X, Y>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_LONG_OPS);
}
template <typename X, typename Y>
void ReduceLongFunction<X, Y>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_LONG_OPS);
}
template <typename X, typename Y>
void ReduceLongFunction<X, Y>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS);
}
template <typename X, typename Z>
template <typename OpType>
void _CUDA_H ReduceLongFunction<X,Z>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
void _CUDA_H ReduceLongFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -215,21 +202,18 @@ namespace functions {
template <typename X, typename Z>
template<typename OpType>
void _CUDA_H ReduceLongFunction<X,Z>::exec(void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vresult,
Nd4jLong *resultShapeInfo) {
// FIXME: wtf???
void _CUDA_H ReduceLongFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *vresult, const Nd4jLong *resultShapeInfo) {
auto z = reinterpret_cast<Z*>(vresult);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
}
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];

View File

@ -34,12 +34,10 @@ namespace functions {
namespace reduce {
template <typename X>
template <typename OpType>
void _CUDA_H ReduceSameFunction<X>::execScalar(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
void _CUDA_H ReduceSameFunction<X>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -95,10 +93,8 @@ namespace functions {
template <typename X>
template <typename OpType>
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
X _CUDA_H ReduceSameFunction<X>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo);
@ -120,33 +116,27 @@ namespace functions {
template <typename X>
X ReduceSameFunction<X>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_SAME_OPS);
}
template <typename X>
void ReduceSameFunction<X>::execScalar(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_SAME_OPS);
}
template <typename X>
void ReduceSameFunction<X>::exec(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo,
extraParams,
@ -161,17 +151,14 @@ namespace functions {
template <typename X>
template <typename OpType>
void _CUDA_H ReduceSameFunction<X>::exec(void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
void _CUDA_H ReduceSameFunction<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -224,21 +211,18 @@ namespace functions {
template <typename X>
template<typename OpType>
void _CUDA_H ReduceSameFunction<X>::exec(void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *zShapeInfo) {
// FIXME: wtf???
void _CUDA_H ReduceSameFunction<X>::exec(const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *vz, const Nd4jLong *zShapeInfo) {
auto z = reinterpret_cast<X*>(vz);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
}
template <typename X>
template <typename OpType>
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
X _CUDA_H ReduceSameFunction<X>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
X intermediate[64];

View File

@ -34,13 +34,13 @@ namespace reduce3 {
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo) {
void Reduce3<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -134,10 +134,10 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
void Reduce3<X,Y>::execScalar(const int opNum,
void *vx, Nd4jLong *xShapeInfo,
void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo) {
const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS);
}
@ -146,14 +146,15 @@ void Reduce3<X,Y>::execScalar(const int opNum,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int64_t start, int64_t stop) {
void Reduce3<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<const X*>(vx);
auto y = reinterpret_cast<const X*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<Z*>(vextraParams);
@ -171,15 +172,16 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void Reduce3<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) {
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
#ifdef INLINE_LOOPS
@ -193,16 +195,17 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
void Reduce3<X,Z>:: execAll(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) {
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z*>(vextraParams);
@ -215,12 +218,13 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
void Reduce3<X,Y>::exec( const int opNum,
void *vx, Nd4jLong *xShapeInfo,
void Reduce3<X,Y>::exec(const int opNum,
const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int64_t start, int64_t stop) {
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS);
}
@ -228,13 +232,14 @@ void Reduce3<X,Y>::exec( const int opNum,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
void Reduce3<X,Y>::exec( const int opNum,
void *vx, Nd4jLong *xShapeInfo,
void Reduce3<X,Y>::exec(const int opNum,
const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) {
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS);
}
@ -243,13 +248,14 @@ void Reduce3<X,Y>::exec( const int opNum,
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
void Reduce3<X,Y>::execAll(const int opNum,
void *vx, Nd4jLong *xShapeInfo,
void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) {
const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS);
}

View File

@ -34,18 +34,18 @@ namespace scalar {
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, Nd4jLong *zShapeInfo,
void *vscalars,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
void ScalarTransform<X, Y, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo,
const void *vscalars,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto scalars = reinterpret_cast<Y *>(vscalars);
auto scalars = reinterpret_cast<const Y *>(vscalars);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
if (zTadShapeInfo == nullptr) {
@ -92,14 +92,14 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
void ScalarTransform<X,Y,Z>::transform(int opNum,
void *x, Nd4jLong *xShapeInfo,
void *extraParams,
void *z, Nd4jLong *zShapeInfo,
void *scalars,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
const void *scalars,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS);
}
@ -107,12 +107,12 @@ void ScalarTransform<X,Y,Z>::transform(int opNum,
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
void ScalarTransform<X, Y, Z>::transform(const int opNum,
void *x, Nd4jLong xStride,
void *z, Nd4jLong zStride,
void *scalar,
void *extraParams,
const uint64_t n,
const uint64_t start, const uint64_t stop) {
const void *x, Nd4jLong xStride,
void *z, Nd4jLong zStride,
const void *scalar,
void *extraParams,
const uint64_t n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS);
}
@ -120,11 +120,11 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
void ScalarTransform<X, Y, Z>::transform(const int opNum,
void *x, Nd4jLong *xShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *scalar,
void *extraParams,
const uint64_t start, const uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
const void *scalar,
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS);
}
@ -132,15 +132,15 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void *vscalar,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
void ScalarTransform<X, Y, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
const void *vscalar,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<Y *>(vscalar)[0];
auto scalar = reinterpret_cast<const Y *>(vscalar)[0];
auto extraParams = reinterpret_cast<Z *>(vextraParams);
const auto len = shape::length(xShapeInfo);
@ -181,15 +181,15 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws,
void *vz, Nd4jLong zEws,
void *vscalar,
void *vextraParams,
const uint64_t len, const uint64_t start, const uint64_t stop) {
void ScalarTransform<X, Y, Z>::transform(const void *vx, Nd4jLong xEws,
void *vz, Nd4jLong zEws,
const void *vscalar,
void *vextraParams,
const uint64_t len, const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<Y *>(vscalar)[0];
auto scalar = reinterpret_cast<const Y *>(vscalar)[0];
auto extraParams = reinterpret_cast<Z *>(vextraParams);
if (xEws == 1 && zEws == 1) {

View File

@ -34,18 +34,18 @@ namespace functions {
template<typename X, typename Z>
template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, Nd4jLong *zShapeInfo,
void *vscalars,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
void ScalarBoolTransform<X, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo,
const void *vscalars,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto scalars = reinterpret_cast<X *>(vscalars);
auto scalars = reinterpret_cast<const X *>(vscalars);
auto extraParams = reinterpret_cast<X *>(vextraParams);
if (zTadShapeInfo == nullptr) {
@ -92,60 +92,50 @@ namespace functions {
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
void *scalars,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
const void *scalars,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS);
}
template<typename X, typename Y>
void ScalarBoolTransform<X, Y>::transform(const int opNum,
void *x,
Nd4jLong xEws,
void *z,
Nd4jLong zEws,
void *scalar,
void *extraParams,
const uint64_t n,
const uint64_t start, const uint64_t stop) {
const void *x, Nd4jLong xEws,
void *z, Nd4jLong zEws,
const void *scalar,
void *extraParams,
const uint64_t n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS);
}
template<typename X, typename Y>
void ScalarBoolTransform<X, Y>::transform(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *scalar,
void *extraParams,
const uint64_t start, const uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
const void *scalar,
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS);
}
template<typename X, typename Z>
template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx,
Nd4jLong *xShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vscalar,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
void ScalarBoolTransform<X, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
const void *vscalar,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams);
auto xEws = shape::elementWiseStride(xShapeInfo);
@ -185,18 +175,16 @@ namespace functions {
template<typename X, typename Z>
template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx,
Nd4jLong xEws,
void *vz,
Nd4jLong zEws,
void *vscalar,
void *vextraParams,
const uint64_t len,
const uint64_t start, const uint64_t stop) {
void ScalarBoolTransform<X, Z>::transform(const void *vx, Nd4jLong xEws,
void *vz, Nd4jLong zEws,
const void *vscalar,
void *vextraParams,
const uint64_t len,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams);
if (xEws == 1 && zEws == 1) {

View File

@ -34,18 +34,18 @@ namespace functions {
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, Nd4jLong *zShapeInfo,
void *vscalars,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
void ScalarIntTransform<X>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo,
const void *vscalars,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto scalars = reinterpret_cast<X *>(vscalars);
auto scalars = reinterpret_cast<const X *>(vscalars);
auto extraParams = reinterpret_cast<X *>(vextraParams);
if (zTadShapeInfo == nullptr) {
@ -92,19 +92,14 @@ namespace functions {
template<typename X>
void ScalarIntTransform<X>::transform(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
void *scalars,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
const void *scalars,
int *dimension, int dimensionLength,
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS);
}
@ -112,42 +107,35 @@ namespace functions {
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum,
void *x,
Nd4jLong xEws,
void *z,
Nd4jLong zEws,
void *scalar,
void *extraParams,
const uint64_t n,
const uint64_t start, const uint64_t stop) {
const void *x, Nd4jLong xEws,
void *z, Nd4jLong zEws,
const void *scalar,
void *extraParams,
const uint64_t n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS);
}
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *scalar,
void *extraParams,
const uint64_t start, const uint64_t stop) {
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
const void *scalar,
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS);
}
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx,
Nd4jLong *xShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vscalar,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
void ScalarIntTransform<X>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
const void *vscalar, void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams);
auto xEws = shape::elementWiseStride(xShapeInfo);
@ -187,18 +175,15 @@ namespace functions {
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx,
Nd4jLong xEws,
void *vz,
Nd4jLong zEws,
void *vscalar,
void *vextraParams,
const uint64_t len,
const uint64_t start, const uint64_t stop) {
void ScalarIntTransform<X>::transform(const void *vx, Nd4jLong xEws,
void *vz, Nd4jLong zEws,
const void *vscalar,
void *vextraParams,
const uint64_t len, const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams);
if (scalar < (sizeof(X) * 8)) {

View File

@ -34,54 +34,46 @@ namespace functions {
template <typename X, typename Y>
Y SummaryStatsReduce<X,Y>::execScalar(const int opNum,
const bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
const bool biasCorrected,
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), SUMMARY_STATS_OPS);
}
template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::execScalar(const int opNum,
const bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo) {
const bool biasCorrected,
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS);
}
template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::exec(const int opNum,
const bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength) {
const bool biasCorrected,
const void *x, const Nd4jLong *xShapeInfo,
void *extraParams,
void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS);
}
template <typename X, typename Z>
template <typename OpType >
void SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected,
void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo) {
const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo) {
auto z = reinterpret_cast<Z*>(vz);
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
}
template <typename X, typename Z>
template <typename OpType >
Z SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected, void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
Z SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
SummaryStatsData<X> startingIndex;
@ -105,15 +97,12 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType >
void SummaryStatsReduce<X,Z>::exec(const bool biasCorrected,
void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength) {
const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
auto resultLength = shape::length(zShapeInfo);

View File

@ -30,25 +30,23 @@ namespace functions {
namespace transform {
template <typename X, typename Y>
void TransformAny<X, Y>::exec(
int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
void TransformAny<X, Y>::exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS);
}
/////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vz,Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
void _CUDA_H TransformAny<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -30,27 +30,22 @@ namespace functions {
namespace transform {
template <typename X, typename Y>
void TransformBool<X, Y>::exec(
int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
void TransformBool<X, Y>::exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS);
}
template <typename X, typename Z>
template<typename OpType>
void _CUDA_H TransformBool<X, Z>::exec(
void *vx,
Nd4jLong *xShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
void _CUDA_H TransformBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -29,27 +29,22 @@ using namespace simdOps;
namespace functions {
namespace transform {
template <typename X, typename Y>
void TransformFloat<X, Y>::exec(
int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
void TransformFloat<X, Y>::exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS);
}
template <typename X, typename Z>
template<typename OpType>
void _CUDA_H TransformFloat<X, Z>::exec(
void *vx,
Nd4jLong *xShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
void _CUDA_H TransformFloat<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);

View File

@ -30,24 +30,22 @@ namespace functions {
namespace transform {
template <typename X>
void TransformSame<X>::exec(
int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams, uint64_t threadId, uint64_t numThreads) {
void TransformSame<X>::exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS);
}
template <typename X>
template<typename OpType>
void _CUDA_H TransformSame<X>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void _CUDA_H TransformSame<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -30,26 +30,23 @@ namespace functions {
namespace transform {
template <typename X>
void TransformStrict<X>::exec(
int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams, uint64_t threadId, uint64_t numThreads) {
void TransformStrict<X>::exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo,
void *z,
const Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS);
}
template <typename X>
template<typename OpType>
void _CUDA_H TransformStrict<X>::exec(
void *vx,
Nd4jLong *xShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
void _CUDA_H TransformStrict<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams,
uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx);
auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -34,22 +34,22 @@ using namespace simdOps;
template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastSimple(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void const* y,
Nd4jLong const* yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastSimple(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo ) {
static __global__ void broadcastSimple(const void const* x, const Nd4jLong const* xShapeInfo,
const void const* y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong const* zShapeInfo ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
@ -57,14 +57,14 @@ static __global__ void broadcastSimple(const void *x, const Nd4jLong *xShapeInfo
template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastInverseSimple(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void const* y,
Nd4jLong const* yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
@ -73,17 +73,17 @@ static __global__ void broadcastInverseSimple(
namespace functions {
namespace broadcast {
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo);
}
static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
template<typename X, typename Y, typename Z>
template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
@ -94,14 +94,14 @@ namespace functions {
}
template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum);
}
template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong const* zShapeInfo) {
DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum);
@ -109,12 +109,12 @@ namespace functions {
template<typename X, typename Y, typename Z>
template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void Broadcast<X,Y,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastInverseSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void Broadcast<X,Y,Z>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TTT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum);
@ -123,19 +123,19 @@ namespace functions {
template<typename X, typename Y, typename Z>
template <typename OpType>
__device__ void Broadcast<X,Y,Z>::transformInverseCuda(
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void* vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets;
}
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<Y*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<Y const*>(vy);
auto z = reinterpret_cast<Z*>(vz);
//decompose in to several sub tads after
@ -189,19 +189,19 @@ namespace functions {
template<typename X, typename Y, typename Z>
template <typename OpType>
__device__ void Broadcast<X,Y,Z>::transformCuda(
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets;
}
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<Y*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<Y const*>(vy);
auto z = reinterpret_cast<Z*>(vz);
//decompose in to several sub tads after

View File

@ -34,24 +34,24 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolSimple(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void const* y,
Nd4jLong const* yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolSimple(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
static __global__ void broadcastBoolSimple(const void const* x, const Nd4jLong const* xShapeInfo,
const void const* y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong const* zShapeInfo,
void *extraParams) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
@ -59,15 +59,15 @@ static __global__ void broadcastBoolSimple(const void *x, const Nd4jLong *xShape
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolInverseSimple(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void const* y,
Nd4jLong const* yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
@ -78,7 +78,7 @@ namespace broadcast {
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
@ -98,7 +98,7 @@ __host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStr
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum);
@ -119,14 +119,14 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum);
@ -136,20 +136,20 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
template<typename X, typename Z>
template <typename OpType>
__device__ void BroadcastBool<X,Z>::transformInverseCuda(
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets;
}
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -198,20 +198,20 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
template<typename X, typename Z>
template <typename OpType>
__device__ void BroadcastBool<X,Z>::transformCuda(
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets;
}
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -235,7 +235,7 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
__syncthreads();
__shared__ Z *rZ;
__shared__ X *rX;
__shared__ X const* rX;
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

View File

@ -34,23 +34,23 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastIntSimple(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void const* y,
Nd4jLong const* yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastIntSimple(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) {
static __global__ void broadcastIntSimple(const void *x, const Nd4jLong const* xShapeInfo,
const void *y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong const* zShapeInfo) {
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
@ -58,14 +58,14 @@ static __global__ void broadcastIntSimple(const void *x, const Nd4jLong *xShapeI
//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastBoolInverseSimple(
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void const* y,
Nd4jLong const* yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastInt<X>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
@ -75,7 +75,7 @@ namespace broadcast {
//////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
@ -92,16 +92,16 @@ __host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream
//////////////////////////////////////////////////////////////////////////
template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) {
const void *x, const Nd4jLong const* xShapeInfo,
const void *y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong const* zShapeInfo) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS))
}
@ -109,13 +109,13 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
//////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastInt<X>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastBoolInverseSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
__host__ void BroadcastInt<X>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
__host__ void BroadcastInt<X>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
}
@ -123,19 +123,19 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
template<typename X>
template <typename OpType>
__device__ void BroadcastInt<X>::transformInverseCuda(
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets;
}
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<X*>(vz);
//decompose in to several sub tads after
@ -183,19 +183,19 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
template<typename X>
template <typename OpType>
__device__ void BroadcastInt<X>::transformCuda(
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets;
}
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<X*>(vz);
//decompose in to several sub tads after
@ -218,7 +218,7 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
__syncthreads();
__shared__ X *rZ;
__shared__ X *rX;
__shared__ X const* rX;
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
@ -250,9 +250,9 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
//////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpType>
__device__ void BroadcastInt<X>::transformCuda(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
__device__ void BroadcastInt<X>::transformCuda(const void *vx, const Nd4jLong const* xShapeInfo,
const void *vy, const Nd4jLong const* yShapeInfo,
void *vz, const Nd4jLong const* zShapeInfo) {
const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy);

View File

@ -31,14 +31,14 @@ using namespace simdOps;
template <typename X, typename Z>
static __global__ void simpleIndexReduceGeneric(const int op,
void *dx,
Nd4jLong *xShapeInfo, int xRank,
void const* dx,
Nd4jLong const* xShapeInfo, int xRank,
void *extraParams,
void *result,
Nd4jLong *zShapeInfo, int zRank,
Nd4jLong const* zShapeInfo, int zRank,
int *dimension,
int dimensionLength,
int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) {
functions::indexreduce::IndexReduce<X, Z>::transform(op,dx,xShapeInfo,extraParams,result,zShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets);
}
@ -49,15 +49,15 @@ namespace functions {
template <typename X, typename Z>
_CUDA_H void IndexReduce<X,Z>::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream,
const int opNum,
void *dx, Nd4jLong *xShapeInfo,
void const* dx, Nd4jLong const* xShapeInfo,
int xRank,
void *extraParams,
void *result, Nd4jLong *zShapeInfo,
void *result, Nd4jLong const* zShapeInfo,
int zRank,
int *dimension, int dimensionLength,
int postProcessOrNot,
int *allocationBuffer, void *reductionBuffer,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) {
simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(opNum,
dx, xShapeInfo, xRank,
@ -70,7 +70,7 @@ namespace functions {
}
template <typename X, typename Z>
_CUDA_H void IndexReduce<X, Z>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
_CUDA_H void IndexReduce<X, Z>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void const* dx, Nd4jLong const* xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong const* zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) {
simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(
opNum,
dx,
@ -154,35 +154,35 @@ namespace functions {
template <typename X, typename Y>
__device__ void IndexReduce<X, Y>::transform(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void const* x,
Nd4jLong const* xShapeInfo,
void *extraParams,
void *result,
Nd4jLong *zShapeInfo,
Nd4jLong const* zShapeInfo,
int *dimension,
int dimensionLength,
int postProcessOrNot,
int *allocationBuffer,
void *reductionBuffer,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) {
Nd4jLong const* tadShapeInfo,
Nd4jLong const* tadOffset) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
}
template <typename X, typename Z>
template <typename OpType>
__device__ void IndexReduce<X, Z>::transform(void *vdx, Nd4jLong *xShapeInfo,
__device__ void IndexReduce<X, Z>::transform(void const* vdx, Nd4jLong const* xShapeInfo,
void *vextraParams,
void *vz, Nd4jLong *zShapeInfo,
void* vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength,
int postProcessOrNot,
int *allocationBuffer, void *vreductionBuffer,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets){
Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets){
/**int
* Gpu information for the problem
*/
auto dx = reinterpret_cast<X*>(vdx);
auto dx = reinterpret_cast<X const*>(vdx);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = static_cast<X*>(vextraParams);
auto reductionBuffer = static_cast<X*>(vreductionBuffer);

View File

@ -28,13 +28,13 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams) {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<Y*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<Y const*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<Z*>(vextraParams);
@ -91,9 +91,9 @@ namespace pairwise_transforms {
template<typename X, typename Y, typename Z>
template<typename OpType>
void __host__ PairWiseTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams){
pairwiseSimpleShaped<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
@ -101,7 +101,7 @@ void __host__ PairWiseTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cud
////////////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
void __host__ PairWiseTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) {
void __host__ PairWiseTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void* vextraParams) {
DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS);
}

View File

@ -28,13 +28,13 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams) {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -92,9 +92,9 @@ namespace pairwise_transforms {
template<typename X, typename Z>
template<typename OpType>
void _CUDA_H PairWiseBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams){
pairwiseSimpleShaped<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
@ -103,7 +103,7 @@ void _CUDA_H PairWiseBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cu
////////////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) {
void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) {
auto xType = sd::DataTypeUtils::fromT<X>();
auto yType = sd::DataTypeUtils::fromT<Y>();

View File

@ -28,13 +28,13 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams) {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<X*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -92,9 +92,9 @@ namespace pairwise_transforms {
template<typename X>
template<typename OpType>
void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void const* vx, Nd4jLong const* xShapeInfo,
void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams){
pairwiseSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
@ -103,7 +103,7 @@ void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaS
////////////////////////////////////////////////////////////////////////////////
template<typename X>
void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) {
void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) {
auto xType = sd::DataTypeUtils::fromT<X>();
DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);

Some files were not shown because too many files have changed in this diff Show More