Legacy API changes (#441)

* initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* another initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* another initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* one more initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored buffer() and shapeInfo() methods usage with NDArray class.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt Graph class methods to use const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt choose op to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt where op shape method to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt lstsq op to use constant empty shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt matrix_diag_part op shape routine to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt determinant ops to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt mean_pairwssqerr_loss ops to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape methods for loss ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt log_loss op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape methods for ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt dilation2d ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted deconv2d ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted dynamicRNN op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for lstm layer ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few updates

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* first cuda tweak

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Adopt constant shapes for sconv2d ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt constant shapes for gru ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt constant shapes with shape methods for segment ops and so on.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted constant shapes with unsorted_segment_* ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted constant shapes with gamma op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods of reduce_stddev ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for reduce_* ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape method for squeeze op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt strided_slice shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored concat op shape method to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape method for mirror_pad op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted split op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted tile ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added const cast for mkldnn routines handles.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cosmetic changes to proper usage of constant pointers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored depthToSpace helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored histogram helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored im2col helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored gather and gatherND helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage on percentile helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed gather shape with helpers and range buffer usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with space to depth helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage and constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with LUP decomposition>

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored onehot_ helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pad and prefix to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactoed softmax helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed space to batch helpers to use buffers properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed stack and split helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with sparse to dense helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with mindistance_ helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with tile helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage with legacy pairwise bool ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored a couple of methods to adopt constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed broadcasting with constant shape."

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const usage with inplace reverse and constant shapes with legacy reduction.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored legacy ops with const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored sort to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected sort for constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage with special methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored Context to conform with constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* CUDA broadcasting headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* pairwise/indexreduce/random headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored native ops to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* legacy reduce3/scalar headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Corrected pullRow signature and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected routines to proper use of constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tests to use constant shapes properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored legacy ops tests to use constant shapes properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage with NDArray tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed native ops tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed special concat routine.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with a test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored TAD.h and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored calcStrides* routines to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed miscelaneous errors with constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* NativeOps const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Corrected definitions for declared functions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* NativeOps const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed const shapes with shape routines.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed shape method for broadcastable case.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* xw_plus_b BP shape fn restored

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed signatures with broadcasting.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Repaired backprops shape methods for a set of operations.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored broadcast bool for cuda.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored methods for 3 args with const qualifier.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed a couple of kernel signatures for broadcasting.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels signatures for const buffers and shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise methods to persistent buffers and shapes usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt const to buffers and shapes with kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt const to buffers and shapes with scalar kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored indexreduce kernels signatures to use const buffers and shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise kernels to adopt cons shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise bool kernels to adopt cons shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored random special ops to conform with const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored native ops to conform with const shapes and buffers under cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cosmetical changes only.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes and buffers error.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected start pos routine.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored methods to conform with const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored helpers to use proper methods instead.

Signed-off-by: shugeo <sgazeos@gmail.com>

* bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed execScalar declaration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed execScalar declaration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected const shape cases with sort and so on.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes for sort.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored kernel declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected kernel declarations to adopt const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed segment helpers kernels declarations and so on to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shape usage with segment and solve helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernel declaration with adjustWeight helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed cuda implementations for constant shape helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted const shape usage with kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted top_k kernels to use const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected kernels declarations to adopt const shapes with helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored NDArray definitions to adopt const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes with image suppression helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Slight improvement with buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage with tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shape usage with definitions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* minor updates on cpu side

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored const shape usage with ConstantDescritor and native ops with cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tear and tile kernels to adopt with const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* softmax_loop fix

Signed-off-by: raver119 <raver119@gmail.com>

* update missing signature

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* softmax again

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* few more missing consts

Signed-off-by: raver119 <raver119@gmail.com>

* new methods updated

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

Co-authored-by: shugeo <sgazeos@gmail.com>
master
raver119 2020-05-09 08:06:14 +03:00 committed by GitHub
parent 0613485654
commit 320924278d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
470 changed files with 7521 additions and 7468 deletions

View File

@ -35,7 +35,7 @@ namespace sd {
std::vector<double> _floatValues; std::vector<double> _floatValues;
public: public:
ConstantDescriptor(double* values, int length); ConstantDescriptor(double* values, int length);
ConstantDescriptor(Nd4jLong* values, int length); ConstantDescriptor(Nd4jLong const* values, int length);
ConstantDescriptor(std::initializer_list<double> values); ConstantDescriptor(std::initializer_list<double> values);
explicit ConstantDescriptor(std::vector<Nd4jLong> &values); explicit ConstantDescriptor(std::vector<Nd4jLong> &values);

View File

@ -125,7 +125,7 @@ namespace sd {
void templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const; void templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const;
template <typename T, typename R> template <typename T, typename R>
FORCEINLINE R templatedGet(void *buffer, const Nd4jLong index) const; FORCEINLINE R templatedGet(void const* buffer, const Nd4jLong index) const;
/* /*
template <typename T, typename R> template <typename T, typename R>
R templatedGetIndex(void *buffer, Nd4jLong *indices) const; R templatedGetIndex(void *buffer, Nd4jLong *indices) const;
@ -193,7 +193,7 @@ namespace sd {
#ifndef __JAVACPP_HACK__ #ifndef __JAVACPP_HACK__
NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const Nd4jLong offset = 0); NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const Nd4jLong offset = 0);
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); NDArray(std::shared_ptr<DataBuffer> buffer, char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
/** /**
* This contructors create scalar array containing string utf8 * This contructors create scalar array containing string utf8
@ -250,13 +250,14 @@ namespace sd {
/** /**
* do not allocate memory, memory for array is passed from outside * do not allocate memory, memory for array is passed from outside
*/ */
NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false); NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false);
NDArray(void *buffer, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false);
/** /**
* do not allocate memory, memory for array is passed from outside * do not allocate memory, memory for array is passed from outside
* we suppose the content of both (device and host) buffers is identical * we suppose the content of both (device and host) buffers is identical
*/ */
NDArray(void *buffer, void *bufferD, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false, const bool isBuffDAlloc = false); NDArray(void *buffer, void *bufferD, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false, bool isBuffDAlloc = false);
/** /**
* copy constructor * copy constructor
@ -277,28 +278,28 @@ namespace sd {
/** /**
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
*/ */
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true); NDArray(const Nd4jLong* shapeInfo, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true);
/** /**
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
* set dtype as array type * set dtype as array type
*/ */
NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true); NDArray(const Nd4jLong* shapeInfo, sd::DataType dtype, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true);
/** /**
* this constructor creates new array using shape information contained in vector argument * this constructor creates new array using shape information contained in vector argument
*/ */
NDArray(const char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); NDArray(char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
/** /**
* This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype * This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype
*/ */
NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); NDArray(char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
/** /**
* this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape * this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
*/ */
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false); NDArray(void *buffer, char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
/** /**
* This method returns new array with the same shape & data type * This method returns new array with the same shape & data type
@ -317,12 +318,12 @@ namespace sd {
* this constructor creates new NDArray with shape matching "other" array, * this constructor creates new NDArray with shape matching "other" array,
* doesn't copy "other" elements into new array !!! * doesn't copy "other" elements into new array !!!
*/ */
explicit NDArray(const NDArray* other, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); explicit NDArray(const NDArray* other, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext());
/** /**
* this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar * this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar
*/ */
NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isScalar = true); NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isScalar = true);
/** /**
* This method blocks until asynchronous operation finishes * This method blocks until asynchronous operation finishes
@ -364,9 +365,11 @@ namespace sd {
* @param offset * @param offset
* @return * @return
*/ */
void *bufferWithOffset(Nd4jLong offset) const; void const* bufferWithOffset(Nd4jLong offset) const;
void* bufferWithOffset(Nd4jLong offset);
void* specialBufferWithOffset(Nd4jLong offset) const; void const* specialBufferWithOffset(Nd4jLong offset) const;
void* specialBufferWithOffset(Nd4jLong offset);
/** /**
* copy assignment operator * copy assignment operator
* in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType * in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType
@ -450,38 +453,39 @@ namespace sd {
/** /**
* returns host buffer * returns host buffer
*/ */
FORCEINLINE void* getBuffer() const;
FORCEINLINE void* buffer(); FORCEINLINE void* buffer();
FORCEINLINE const void* buffer() const;
/** /**
* returns buffer offset (offset is the same for host and device buffers) * returns buffer offset (offset is the same for host and device buffers)
*/ */
FORCEINLINE Nd4jLong getBufferOffset() const; FORCEINLINE Nd4jLong bufferOffset() const;
FORCEINLINE Nd4jLong bufferOffset();
/** /**
* if _bufferD==nullptr return _buffer, else return _bufferD * if _bufferD==nullptr return _buffer, else return _bufferD
*/ */
void* specialBuffer(); void* specialBuffer();
void* getSpecialBuffer() const; const void* specialBuffer() const;
/** /**
* returns device buffer if compilation is for cuda case, otherwise returns host buffer * returns device buffer if compilation is for cuda case, otherwise returns host buffer
*/ */
void* getPlatformBuffer() const;
void* platformBuffer(); void* platformBuffer();
const void* platformBuffer() const;
template <typename T> template <typename T>
T* bufferAsT() const; T* bufferAsT();
template <typename T>
const T* bufferAsT() const;
/** /**
* returns _shapeInfo * returns _shapeInfo
*/ */
FORCEINLINE Nd4jLong* shapeInfo(); FORCEINLINE const Nd4jLong* shapeInfo() const;
FORCEINLINE Nd4jLong* getShapeInfo() const;
/** /**
@ -493,12 +497,9 @@ namespace sd {
/** /**
* if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD
*/ */
FORCEINLINE Nd4jLong* specialShapeInfo(); FORCEINLINE const Nd4jLong* specialShapeInfo() const;
FORCEINLINE Nd4jLong* getSpecialShapeInfo() const;
const Nd4jLong* platformShapeInfo() const;
Nd4jLong* platformShapeInfo();
Nd4jLong* getPlatformShapeInfo() const;
/** /**
* permutes (in-place) the dimensions in array according to "dimensions" array * permutes (in-place) the dimensions in array according to "dimensions" array
@ -1509,8 +1510,8 @@ bool NDArray::isAttached() {
} }
template <typename T, typename R> template <typename T, typename R>
FORCEINLINE R NDArray::templatedGet(void *buffer, Nd4jLong index) const { FORCEINLINE R NDArray::templatedGet(void const* buffer, Nd4jLong index) const {
auto b = reinterpret_cast<T*>(buffer); auto b = reinterpret_cast<T const*>(buffer);
auto v = static_cast<R>(b[index]); auto v = static_cast<R>(b[index]);
return v; return v;
} }
@ -1625,9 +1626,9 @@ bool NDArray::nonNull() const {
return true; return true;
if(!Environment::getInstance()->isCPU()) if(!Environment::getInstance()->isCPU())
return getDataBuffer()->special() != nullptr && getSpecialShapeInfo() != nullptr; return getDataBuffer()->special() != nullptr && specialShapeInfo() != nullptr;
return getDataBuffer()->primary() != nullptr && getShapeInfo() != nullptr; return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1744,7 +1745,7 @@ bool NDArray::isEmpty() const {
if (this->_shapeInfo == nullptr) if (this->_shapeInfo == nullptr)
return false; return false;
return ArrayOptions::arrayType(this->getShapeInfo()) == ArrayType::EMPTY; return ArrayOptions::arrayType(this->shapeInfo()) == ArrayType::EMPTY;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1804,7 +1805,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j) {
syncToHost(); syncToHost();
Nd4jLong coords[2] = {i, j}; Nd4jLong coords[2] = {i, j};
auto offset = shape::getOffset(getShapeInfo(), coords); auto offset = shape::getOffset(shapeInfo(), coords);
tickWriteHost(); tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset))); return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
} }
@ -1821,7 +1822,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
syncToHost(); syncToHost();
Nd4jLong coords[3] = {i, j, k}; Nd4jLong coords[3] = {i, j, k};
auto offset = shape::getOffset(getShapeInfo(), coords); auto offset = shape::getOffset(shapeInfo(), coords);
tickWriteHost(); tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset))); return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
} }
@ -1838,7 +1839,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLo
syncToHost(); syncToHost();
Nd4jLong coords[4] = {i, j, k, w}; Nd4jLong coords[4] = {i, j, k, w};
auto offset = shape::getOffset(getShapeInfo(), coords); auto offset = shape::getOffset(shapeInfo(), coords);
tickWriteHost(); tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset))); return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
} }
@ -1856,7 +1857,7 @@ T NDArray::t(const Nd4jLong i) const {
syncToHost(); syncToHost();
tickReadHost(); tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(getOffset(i)))); return *(reinterpret_cast<const T*>(bufferWithOffset(getOffset(i))));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1872,9 +1873,9 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
syncToHost(); syncToHost();
Nd4jLong coords[2] = {i, j}; Nd4jLong coords[2] = {i, j};
auto offset = shape::getOffset(getShapeInfo(), coords); auto offset = shape::getOffset(shapeInfo(), coords);
tickReadHost(); tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset))); return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
} }
template <typename T> template <typename T>
@ -1889,9 +1890,9 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
syncToHost(); syncToHost();
Nd4jLong coords[3] = {i, j, k}; Nd4jLong coords[3] = {i, j, k};
auto offset = shape::getOffset(getShapeInfo(), coords); auto offset = shape::getOffset(shapeInfo(), coords);
tickReadHost(); tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset))); return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
} }
template <typename T> template <typename T>
@ -1906,9 +1907,9 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
syncToHost(); syncToHost();
Nd4jLong coords[4] = {i, j, k, w}; Nd4jLong coords[4] = {i, j, k, w};
auto offset = shape::getOffset(getShapeInfo(), coords); auto offset = shape::getOffset(shapeInfo(), coords);
tickReadHost(); tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset))); return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
} }
#ifndef __JAVACPP_HACK__ #ifndef __JAVACPP_HACK__
@ -1924,8 +1925,7 @@ std::shared_ptr<DataBuffer> NDArray::dataBuffer() {
#endif #endif
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::getBuffer() const { const void* NDArray::buffer() const {
return _buffer->primary() != nullptr ? static_cast<int8_t*>(_buffer->primary()) + (_offset * sizeOfT()) : nullptr; return _buffer->primary() != nullptr ? static_cast<int8_t*>(_buffer->primary()) + (_offset * sizeOfT()) : nullptr;
} }
@ -1934,18 +1934,13 @@ void* NDArray::buffer() {
return _buffer->primary() != nullptr ? static_cast<int8_t*>(_buffer->primary()) + (_offset * sizeOfT()) : nullptr; return _buffer->primary() != nullptr ? static_cast<int8_t*>(_buffer->primary()) + (_offset * sizeOfT()) : nullptr;
} }
////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::getShapeInfo() const {
return _shapeInfo;
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::shapeInfo() { const Nd4jLong* NDArray::shapeInfo() const {
return _shapeInfo; return _shapeInfo;
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::specialShapeInfo() { const Nd4jLong* NDArray::specialShapeInfo() const {
if (_shapeInfoD == nullptr) if (_shapeInfoD == nullptr)
return _shapeInfo; return _shapeInfo;
// FIXME: this should be fixed once CUDA backend added // FIXME: this should be fixed once CUDA backend added
@ -1953,23 +1948,10 @@ Nd4jLong* NDArray::specialShapeInfo() {
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
Nd4jLong NDArray::getBufferOffset() const { Nd4jLong NDArray::bufferOffset() const {
return _offset; return _offset;
} }
////////////////////////////////////////////////////////////////////////
Nd4jLong NDArray::bufferOffset() {
return _offset;
}
////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::getSpecialShapeInfo() const{
if (_shapeInfoD == nullptr)
return _shapeInfo;
// FIXME: this should be fixed once CUDA backend added
return _shapeInfoD;
}
#if defined(__CUDACC__) //&& defined(BUILD_TESTS) #if defined(__CUDACC__) //&& defined(BUILD_TESTS)
// for CUDA we need stil stuff inline // for CUDA we need stil stuff inline

File diff suppressed because it is too large Load Diff

View File

@ -23,26 +23,26 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) { static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo); return shape::getIndexOffset(index, shapeInfo);
} }
static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) { static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) {
return shape::length(shapeInfo); return shape::length(shapeInfo);
} }
template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template <typename T, typename Lambda> static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template <typename T, typename Lambda> static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T> template <typename T>
class LambdaHelper { class LambdaHelper {
public: public:
template <typename Lambda> template <typename Lambda>
FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); lambdaKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream); auto err = cudaStreamSynchronize(*stream);
if (err != 0) if (err != 0)
@ -50,7 +50,7 @@ public:
} }
template <typename Lambda> template <typename Lambda>
FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaIndexedKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); lambdaIndexedKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream); auto err = cudaStreamSynchronize(*stream);
if (err != 0) if (err != 0)
@ -58,7 +58,7 @@ public:
} }
template <typename Lambda> template <typename Lambda>
FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); lambdaPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream); auto err = cudaStreamSynchronize(*stream);
if (err != 0) if (err != 0)
@ -66,7 +66,7 @@ public:
} }
template <typename Lambda> template <typename Lambda>
FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaIndexedPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); lambdaIndexedPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream); auto err = cudaStreamSynchronize(*stream);
if (err != 0) if (err != 0)
@ -74,7 +74,7 @@ public:
} }
template <typename Lambda> template <typename Lambda>
FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream, void* vw, Nd4jLong *wShapeInfo, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream,const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
lambdaTriplewiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); lambdaTriplewiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
auto err = cudaStreamSynchronize(*stream); auto err = cudaStreamSynchronize(*stream);
if (err != 0) if (err != 0)
@ -84,8 +84,8 @@ public:
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda> template <typename T, typename Lambda>
static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx); auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz); auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -113,8 +113,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda> template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx); auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz); auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -142,9 +142,9 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda> template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx); auto x = reinterpret_cast<const T*>(vx);
auto y = reinterpret_cast<T*>(vy); auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz); auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -175,9 +175,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda> template <typename T, typename Lambda>
static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto x = reinterpret_cast<T*>(vx); auto x = reinterpret_cast<const T*>(vx);
auto y = reinterpret_cast<T*>(vy); auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz); auto z = reinterpret_cast<T*>(vz);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -208,10 +208,10 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T, typename Lambda> template <typename T, typename Lambda>
static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda) { static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
auto w = reinterpret_cast<T*>(vw); auto w = reinterpret_cast<const T*>(vw);
auto x = reinterpret_cast<T*>(vx); auto x = reinterpret_cast<const T*>(vx);
auto y = reinterpret_cast<T*>(vy); auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz); auto z = reinterpret_cast<T*>(vz);
auto wEws = shape::elementWiseStride(wShapeInfo); auto wEws = shape::elementWiseStride(wShapeInfo);
@ -271,7 +271,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& ta
//throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());
prepareSpecialUse({&target}, {this, &other}); prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({&target}, {this, &other}); registerSpecialUse({&target}, {this, &other});
} }
@ -298,7 +298,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& t
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same");
prepareSpecialUse({&target}, {this, &other}); prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({&target}, {this, &other}); registerSpecialUse({&target}, {this, &other});
} }

View File

@ -28,26 +28,24 @@
namespace sd { namespace sd {
class ND4J_EXPORT ShapeList { class ND4J_EXPORT ShapeList {
protected: protected:
std::vector<Nd4jLong*> _shapes; std::vector<const Nd4jLong*> _shapes;
bool _destroyed = false; bool _destroyed = false;
bool _autoremovable = false; bool _autoremovable = false;
bool _workspace = false; bool _workspace = false;
public: public:
ShapeList(Nd4jLong* shape = nullptr); ShapeList(const Nd4jLong* shape = nullptr);
ShapeList(std::initializer_list<Nd4jLong*> shapes); ShapeList(const std::vector<const Nd4jLong*> &shapes, bool isWorkspace);
ShapeList(std::initializer_list<Nd4jLong*> shapes, bool isWorkspace); ShapeList(const std::vector<const Nd4jLong*>& shapes);
ShapeList(std::vector<Nd4jLong*>& shapes);
//ShapeList(bool autoRemovable); //ShapeList(bool autoRemovable);
~ShapeList(); ~ShapeList();
std::vector<Nd4jLong*>* asVector(); std::vector<const Nd4jLong*>* asVector();
void destroy(); void destroy();
int size(); int size() const;
Nd4jLong* at(int idx); const Nd4jLong* at(int idx);
void push_back(Nd4jLong *shape); void push_back(const Nd4jLong *shape);
void push_back(std::vector<Nd4jLong>& shape);
/** /**
* PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak * PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak

View File

@ -28,18 +28,18 @@ namespace sd {
private: private:
ConstantDataBuffer _tadShape; ConstantDataBuffer _tadShape;
ConstantDataBuffer _tadOffsets; ConstantDataBuffer _tadOffsets;
Nd4jLong _numTads; Nd4jLong _numTads = 0 ;
int _shapeInfoLength; int _shapeInfoLength = 0;
public: public:
explicit TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads); explicit TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads);
TadPack() = default; TadPack() = default;
~TadPack() = default; ~TadPack() = default;
Nd4jLong* primaryShapeInfo() const; const Nd4jLong* primaryShapeInfo() const;
Nd4jLong* primaryOffsets() const; const Nd4jLong* primaryOffsets() const;
Nd4jLong* specialShapeInfo() const; const Nd4jLong* specialShapeInfo() const;
Nd4jLong* specialOffsets() const; const Nd4jLong* specialOffsets() const;
Nd4jLong numberOfTads() const; Nd4jLong numberOfTads() const;
int shapeInfoLength() const; int shapeInfoLength() const;
@ -48,8 +48,8 @@ namespace sd {
* These methods return either primary or special pointers depending on platform binaries were compiled for * These methods return either primary or special pointers depending on platform binaries were compiled for
* @return * @return
*/ */
Nd4jLong *platformShapeInfo() const; const Nd4jLong *platformShapeInfo() const;
Nd4jLong *platformOffsets() const; const Nd4jLong *platformOffsets() const;
}; };
} }

View File

@ -52,10 +52,9 @@ namespace sd {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::platformBuffer() { return buffer(); } void* NDArray::platformBuffer() { return buffer(); }
void* NDArray::getPlatformBuffer() const { return getBuffer(); } void const* NDArray::platformBuffer() const { return buffer(); }
Nd4jLong* NDArray::getPlatformShapeInfo() const { return getShapeInfo(); } Nd4jLong const* NDArray::platformShapeInfo() const { return shapeInfo(); }
Nd4jLong* NDArray::platformShapeInfo() { return shapeInfo(); }
void NDArray::syncToDevice() const { } void NDArray::syncToDevice() const { }
void NDArray::syncToHost() const { } void NDArray::syncToHost() const { }
@ -85,15 +84,15 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
upper = target.sizeAt(-1); upper = target.sizeAt(-1);
const T value = static_cast<T>(val); const T value = static_cast<T>(val);
const auto x = reinterpret_cast<const T*>(getBuffer()); const auto x = reinterpret_cast<const T*>(buffer());
auto z = reinterpret_cast<T*>(target.getBuffer()); auto z = reinterpret_cast<T*>(target.buffer());
const int xRank = rankOf(); const int xRank = rankOf();
const int zRank = target.rankOf(); const int zRank = target.rankOf();
const auto zLen = target.lengthOf(); const auto zLen = target.lengthOf();
const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo()); const bool areSameOffsets = shape::haveSameShapeAndStrides(shapeInfo(), target.shapeInfo());
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
@ -101,8 +100,8 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
for (auto i = start; i < stop; i++) { for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, target.getShapeInfo(), coords); shape::index2coordsCPU(start, i, target.shapeInfo(), coords);
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords); const auto zOffset = shape::getOffset(target.shapeInfo(), coords);
// if( (row + upper < col) || (row + lower > col) ) // if( (row + upper < col) || (row + lower > col) )
if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
@ -113,7 +112,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
coords[0] = coords[1]; coords[0] = coords[1];
} }
const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords); const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(shapeInfo(), coords);
z[zOffset] = x[xOffset]; z[zOffset] = x[xOffset];
if (xRank != zRank) // restore first coordinate if (xRank != zRank) // restore first coordinate
@ -140,7 +139,7 @@ void NDArray::setIdentity() {
for(int j = 0; j < rank; ++j) for(int j = 0; j < rank; ++j)
indices[j] = 1; indices[j] = 1;
Nd4jLong offset = shape::getOffset(getShapeInfo(), indices); Nd4jLong offset = shape::getOffset(shapeInfo(), indices);
for(int i = 0; i < rank; ++i) for(int i = 0; i < rank; ++i)
if(minDim > shape[i]) if(minDim > shape[i])
@ -214,23 +213,28 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
} }
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBufferWithOffset(Nd4jLong offset) {
return nullptr;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::specialBufferWithOffset(Nd4jLong offset) const { const void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return nullptr; return nullptr;
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() { void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr) if (_buffer->special() == nullptr)
return getBuffer(); return buffer();
// FIXME: this should be fixed once CUDA backend added // FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT()); return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const { void const* NDArray::specialBuffer() const {
if (_buffer->special() == nullptr) if (_buffer->special() == nullptr)
return getBuffer(); return buffer();
// FIXME: this should be fixed once CUDA backend added // FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT()); return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
} }
@ -253,7 +257,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
NDArray result(*this); NDArray result(*this);
if(diff < 0) { // reshape to higher dimension if(diff < 0) { // reshape to higher dimension
std::vector<Nd4jLong> shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape std::vector<Nd4jLong> shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
result.reshapei(ordering(), shapeNew); result.reshapei(ordering(), shapeNew);
} }
return result; // nothing to do, if diff >= 0 -> identity tile return result; // nothing to do, if diff >= 0 -> identity tile
@ -274,8 +278,8 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) { for (auto i = start; i < stop; i++) {
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo()); auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), i, this->buffer(), yOffset), LIBND4J_TYPES);
} }
}; };
@ -286,8 +290,8 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) { for (auto i = start; i < stop; i++) {
auto xOffset = result.getOffset(i); auto xOffset = result.getOffset(i);
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo()); auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), xOffset, this->buffer(), yOffset), LIBND4J_TYPES);
} }
}; };
@ -307,7 +311,7 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
// evaluate true tile shapeInfo for comparison with target shapeInfo // evaluate true tile shapeInfo for comparison with target shapeInfo
auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace());
if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) { if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) {
delete []newShapeInfo; delete []newShapeInfo;
throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !");
} }
@ -319,14 +323,14 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here
//#pragma omp parallel for simd if(targetLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided) //#pragma omp parallel for simd if(targetLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)
for(Nd4jLong i=0; i<targetLen; ++i) { for(Nd4jLong i=0; i<targetLen; ++i) {
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo()); auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
} }
} }
else if(target.ordering() == 'c' && ews > 1) { else if(target.ordering() == 'c' && ews > 1) {
for(Nd4jLong i=0; i<targetLen; ++i) { for(Nd4jLong i=0; i<targetLen; ++i) {
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo()); auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i*ews, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i*ews, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
} }
} }
else { else {
@ -334,8 +338,8 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
for(Nd4jLong i=0; i<targetLen; ++i) { for(Nd4jLong i=0; i<targetLen; ++i) {
auto xOffset = target.getOffset(i); auto xOffset = target.getOffset(i);
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo()); auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), xOffset, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), xOffset, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
} }
} }
} }
@ -355,8 +359,8 @@ void NDArray::tile(NDArray& target) const {
if(target.ordering() == 'c' && ews >= 1) { if(target.ordering() == 'c' && ews >= 1) {
for(Nd4jLong i=0; i<targetLen; ++i) { for(Nd4jLong i=0; i<targetLen; ++i) {
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo()); auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i*ews, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i*ews, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
} }
} }
else { else {
@ -364,8 +368,8 @@ void NDArray::tile(NDArray& target) const {
for(Nd4jLong i=0; i<targetLen; ++i) { for(Nd4jLong i=0; i<targetLen; ++i) {
auto xOffset = target.getOffset(i); auto xOffset = target.getOffset(i);
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo()); auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), xOffset, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), xOffset, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
} }
} }
} }
@ -388,8 +392,8 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
for (auto i = start; i < stop; i++) { for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, output.getShapeInfo(), coords); shape::index2coordsCPU(start, i, output.shapeInfo(), coords);
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords); const auto zOffset = shape::getOffset(output.shapeInfo(), coords);
temp = coords[axis]; temp = coords[axis];
@ -404,7 +408,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
} else } else
coords[axis] /= repeats[0]; coords[axis] /= repeats[0];
z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords)]; z[zOffset] = x[shape::getOffset(input.shapeInfo(), coords)];
coords[axis] = temp; coords[axis] = temp;
} }

View File

@ -50,16 +50,16 @@
namespace sd { namespace sd {
void* NDArray::platformBuffer() { return specialBuffer(); } void* NDArray::platformBuffer() { return specialBuffer(); }
void* NDArray::getPlatformBuffer() const { return getSpecialBuffer(); } void const* NDArray::platformBuffer() const { return specialBuffer(); }
Nd4jLong* NDArray::getPlatformShapeInfo() const { return getSpecialShapeInfo(); } Nd4jLong const* NDArray::platformShapeInfo() const { return specialShapeInfo(); }
Nd4jLong* NDArray::platformShapeInfo() { return specialShapeInfo(); } //Nd4jLong const* NDArray::platformShapeInfo() { return specialShapeInfo(); }
void NDArray::syncToDevice() const { void NDArray::syncToDevice() const {
auto currentDeviceId = AffinityManager::currentDeviceId(); auto currentDeviceId = AffinityManager::currentDeviceId();
if (currentDeviceId != _deviceId) { if (currentDeviceId != _deviceId) {
// first of all we update shapeInfo // first of all we update shapeInfo
const_cast<NDArray*>(this)->setShapeInfo(this->getShapeInfo()); const_cast<NDArray*>(this)->setShapeInfo(this->shapeInfo());
// now we actually migrate data buffer // now we actually migrate data buffer
_buffer->migrate(); _buffer->migrate();
@ -142,7 +142,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
PointersManager manager(getContext(), "NDArray::fillAsTriangular"); PointersManager manager(getContext(), "NDArray::fillAsTriangular");
NDArray::prepareSpecialUse({&target}, {this}); NDArray::prepareSpecialUse({&target}, {this});
fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target.getPlatformBuffer(), target.getPlatformShapeInfo(), static_cast<T>(val), lower, upper); fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(platformBuffer(), platformShapeInfo(), target.platformBuffer(), target.platformShapeInfo(), static_cast<T>(val), lower, upper);
NDArray::registerSpecialUse({&target}, {this}); NDArray::registerSpecialUse({&target}, {this});
manager.synchronize(); manager.synchronize();
@ -206,7 +206,7 @@ void NDArray::setIdentity() {
PointersManager manager(getContext(), "NDArray::setIdentity"); PointersManager manager(getContext(), "NDArray::setIdentity");
syncToDevice(); syncToDevice();
BUILD_SINGLE_SELECTOR(dataType(), identityMatrixCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getPlatformBuffer(), getPlatformShapeInfo(), 1.f), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dataType(), identityMatrixCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), platformBuffer(), platformShapeInfo(), 1.f), LIBND4J_TYPES);
tickWriteDevice(); tickWriteDevice();
manager.synchronize(); manager.synchronize();
@ -293,12 +293,16 @@ void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, c
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::syncShape() const { void NDArray::syncShape() const {
cudaMemcpy(getSpecialShapeInfo(), getShapeInfo(), shape::shapeInfoByteLength(getShapeInfo()), cudaMemcpyHostToDevice); cudaMemcpy(const_cast<Nd4jLong*>(specialShapeInfo()), shapeInfo(), shape::shapeInfoByteLength(shapeInfo()), cudaMemcpyHostToDevice);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void* NDArray::specialBufferWithOffset(Nd4jLong offset) const { void const* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return getSpecialBuffer() != nullptr ? static_cast<int8_t*>(getSpecialBuffer()) + (offset * sizeOfT()) : nullptr; return specialBuffer() != nullptr ? static_cast<int8_t const*>(specialBuffer()) + (offset * sizeOfT()) : nullptr;
}
void* NDArray::specialBufferWithOffset(Nd4jLong offset){
return specialBuffer() != nullptr ? static_cast<int8_t*>(specialBuffer()) + (offset * sizeOfT()) : nullptr;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -318,7 +322,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
NDArray result(*this); NDArray result(*this);
if(diff < 0) { // reshape to higher dimension if(diff < 0) { // reshape to higher dimension
std::vector<Nd4jLong> shapeNew = reps; // need to have unities at first "diff" positions of new shape std::vector<Nd4jLong> shapeNew = reps; // need to have unities at first "diff" positions of new shape
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
result.reshapei(ordering(), shapeNew); result.reshapei(ordering(), shapeNew);
} }
return result; // nothing to do, if diff >= 0 -> identity tile return result; // nothing to do, if diff >= 0 -> identity tile
@ -332,13 +336,13 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext());
// fill newBuff, loop through all elements of newBuff // fill newBuff, loop through all elements of newBuff
// looping through getBuffer() goes automatically by means of getSubArrayIndex applying // looping through buffer() goes automatically by means of getSubArrayIndex applying
const auto resultLen = result.lengthOf(); const auto resultLen = result.lengthOf();
auto xType = this->dataType(); auto xType = this->dataType();
auto stream = getContext()->getCudaStream(); auto stream = getContext()->getCudaStream();
prepareSpecialUse({&result}, {this}); prepareSpecialUse({&result}, {this});
BUILD_SINGLE_SELECTOR(xType, tileKernelH, (this->getSpecialBuffer(), this->getSpecialShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), resultLen, stream), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, tileKernelH, (this->specialBuffer(), this->specialShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), resultLen, stream), LIBND4J_TYPES);
registerSpecialUse({&result}, {this}); registerSpecialUse({&result}, {this});
return result; return result;
@ -354,18 +358,18 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
// evaluate true tile shapeInfo for comparison with target shapeInfo // evaluate true tile shapeInfo for comparison with target shapeInfo
auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace());
if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) { if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) {
throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !");
} }
// fill newBuff, loop through all elements of newBuff // fill newBuff, loop through all elements of newBuff
// looping through getBuffer() goes automatically by means of getSubArrayIndex applying // looping through buffer() goes automatically by means of getSubArrayIndex applying
const int ews = target.ews(); const int ews = target.ews();
const int targetLen = target.lengthOf(); const int targetLen = target.lengthOf();
auto stream = getContext()->getCudaStream(); auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this}); prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this}); registerSpecialUse({&target}, {this});
} }
@ -384,7 +388,7 @@ void NDArray::tile(NDArray& target) const {
auto stream = getContext()->getCudaStream(); auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this}); prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this}); registerSpecialUse({&target}, {this});
} }
@ -467,7 +471,7 @@ NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
prepareSpecialUse({&output}, {this}); prepareSpecialUse({&output}, {this});
BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
prepareSpecialUse({&output}, {this}); prepareSpecialUse({&output}, {this});
manager.synchronize(); manager.synchronize();
@ -491,7 +495,7 @@ void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& t
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
prepareSpecialUse({&target}, {this}); prepareSpecialUse({&target}, {this});
BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES);
prepareSpecialUse({&target}, {this}); prepareSpecialUse({&target}, {this});
manager.synchronize(); manager.synchronize();
@ -501,16 +505,20 @@ void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& t
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() { void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr) if (_buffer->special() == nullptr) {
return getBuffer(); syncToDevice();
tickReadHost();
}
// FIXME: this should be fixed once CUDA backend added // FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT()); return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const { void const* NDArray::specialBuffer() const {
if (_buffer->special() == nullptr) if (_buffer->special() == nullptr) {
return getBuffer(); syncToDevice();
tickReadHost();
}
// FIXME: this should be fixed once CUDA backend added // FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT()); return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
} }
@ -526,7 +534,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
printf("%s", msg); printf("%s", msg);
if(host) { if(host) {
if(getBuffer() == nullptr || _length == 0) if(buffer() == nullptr || _length == 0)
{ printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); return; } { printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); return; }
const T* buff = bufferAsT<T>(); const T* buff = bufferAsT<T>();
@ -535,7 +543,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
printf("\n"); printf("\n");
} }
else { else {
if(getSpecialBuffer() == nullptr || _length == 0) if(specialBuffer() == nullptr || _length == 0)
{ printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); return; } { printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); return; }
void* pHost = operator new(sizeof(T) * _length); void* pHost = operator new(sizeof(T) * _length);
@ -545,7 +553,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream())); cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
} }
else else
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); cudaMemcpyAsync(pHost, specialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
cudaError_t cudaResult = cudaStreamSynchronize(*getContext()->getCudaStream()); cudaError_t cudaResult = cudaStreamSynchronize(*getContext()->getCudaStream());
if(cudaResult != 0) if(cudaResult != 0)

View File

@ -28,7 +28,7 @@ namespace sd {
_floatValues.emplace_back(values[e]); _floatValues.emplace_back(values[e]);
} }
ConstantDescriptor::ConstantDescriptor(Nd4jLong * values, int length) { ConstantDescriptor::ConstantDescriptor(Nd4jLong const* values, int length) {
for (int e = 0; e < length; e++) for (int e = 0; e < length; e++)
_integerValues.emplace_back(values[e]); _integerValues.emplace_back(values[e]);
} }

View File

@ -417,7 +417,7 @@ NDArray NDArrayFactory::create(const std::vector<T> &values, sd::LaunchContext *
NDArray res(buffer, ShapeDescriptor::vectorDescriptor(values.size(), DataTypeUtils::fromT<T>()), context); NDArray res(buffer, ShapeDescriptor::vectorDescriptor(values.size(), DataTypeUtils::fromT<T>()), context);
memcpyFromVector<T>(res.getBuffer(), values); memcpyFromVector<T>(res.buffer(), values);
res.tickWriteHost(); res.tickWriteHost();
res.syncToDevice(); res.syncToDevice();

View File

@ -153,7 +153,7 @@ namespace sd {
inputs[e] = _chunks[e]; inputs[e] = _chunks[e];
} }
auto inShapeInfo = inputs[0]->getShapeInfo(); auto inShapeInfo = inputs[0]->shapeInfo();
int rank = shape::rank(inShapeInfo); int rank = shape::rank(inShapeInfo);
NDArray* array = nullptr; NDArray* array = nullptr;

View File

@ -26,7 +26,7 @@ namespace sd {
// _autoremovable = autoRemovable; // _autoremovable = autoRemovable;
// } // }
ShapeList::ShapeList(Nd4jLong* shape) { ShapeList::ShapeList(const Nd4jLong* shape) {
if (shape != nullptr) if (shape != nullptr)
_shapes.push_back(shape); _shapes.push_back(shape);
} }
@ -36,21 +36,15 @@ namespace sd {
destroy(); destroy();
} }
ShapeList::ShapeList(std::initializer_list<Nd4jLong*> shapes) { ShapeList::ShapeList(const std::vector<const Nd4jLong*> &shapes, bool isWorkspace) : ShapeList(shapes){
for (auto v:shapes)
_shapes.push_back(v);
}
ShapeList::ShapeList(std::initializer_list<Nd4jLong*> shapes, bool isWorkspace) : ShapeList(shapes){
_workspace = isWorkspace; _workspace = isWorkspace;
} }
ShapeList::ShapeList(std::vector<Nd4jLong*>& shapes) { ShapeList::ShapeList(const std::vector<const Nd4jLong*>& shapes) {
for (auto v:shapes) _shapes = shapes;
_shapes.push_back(v);
} }
std::vector<Nd4jLong*>* ShapeList::asVector() { std::vector<const Nd4jLong*>* ShapeList::asVector() {
return &_shapes; return &_shapes;
} }
@ -66,33 +60,21 @@ namespace sd {
_destroyed = true; _destroyed = true;
} }
int ShapeList::size() { int ShapeList::size() const {
return (int) _shapes.size(); return (int) _shapes.size();
} }
Nd4jLong* ShapeList::at(int idx) { const Nd4jLong* ShapeList::at(int idx) {
if (_shapes.size() <= idx) if (_shapes.size() <= idx)
throw std::runtime_error("Can't find requested variable by index"); throw std::runtime_error("Can't find requested variable by index");
return _shapes.at(idx); return _shapes.at(idx);
} }
void ShapeList::push_back(Nd4jLong *shape) { void ShapeList::push_back(const Nd4jLong *shape) {
_shapes.push_back(shape); _shapes.push_back(shape);
} }
void ShapeList::push_back(std::vector<Nd4jLong>& shape) {
int dLen = shape::shapeInfoLength(shape.at(0));
if (shape.size() != dLen)
throw std::runtime_error("Bad shape was passed in");
auto nShape = new Nd4jLong[dLen];
std::memcpy(nShape, shape.data(), shape::shapeInfoByteLength(shape.at(0)));
_shapes.push_back(nShape);
}
void ShapeList::detach() { void ShapeList::detach() {
for (int e = 0; e < _shapes.size(); e++) { for (int e = 0; e < _shapes.size(); e++) {
_shapes[e] = shape::detachShape(_shapes[e]); _shapes[e] = shape::detachShape(_shapes[e]);

View File

@ -29,18 +29,19 @@ namespace sd {
_numTads = numTads; _numTads = numTads;
} }
Nd4jLong* TadPack::primaryShapeInfo() const { const Nd4jLong* TadPack::primaryShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.primary()); return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
} }
Nd4jLong* TadPack::primaryOffsets() const {
const Nd4jLong* TadPack::primaryOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary()); return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
} }
Nd4jLong* TadPack::specialShapeInfo() const { const Nd4jLong* TadPack::specialShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.special()); return reinterpret_cast<Nd4jLong *>(_tadShape.special());
} }
Nd4jLong* TadPack::specialOffsets() const { const Nd4jLong* TadPack::specialOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special()); return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
} }
@ -48,11 +49,11 @@ namespace sd {
return _numTads; return _numTads;
} }
Nd4jLong* TadPack::platformShapeInfo() const { const Nd4jLong* TadPack::platformShapeInfo() const {
return sd::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo(); return sd::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
} }
Nd4jLong* TadPack::platformOffsets() const { const Nd4jLong* TadPack::platformOffsets() const {
return sd::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets(); return sd::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
} }

View File

@ -196,12 +196,14 @@ namespace sd {
#endif #endif
void setInputArray(int index, NDArray *array, bool removable = false); void setInputArray(int index, NDArray *array, bool removable = false);
void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); void setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo);
void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); void setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo);
void setInputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo);
void setOutputArray(int index, NDArray *array, bool removable = false); void setOutputArray(int index, NDArray *array, bool removable = false);
void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); void setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo);
void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); void setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo);
void setOutputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo);
void setTArguments(double *arguments, int numberOfArguments); void setTArguments(double *arguments, int numberOfArguments);
void setIArguments(Nd4jLong *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments);

View File

@ -407,8 +407,12 @@ namespace sd {
_handles.emplace_back(array); _handles.emplace_back(array);
} }
void Context::setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { void Context::setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) {
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong *>(shapeInfo)); this->setInputArray(index, buffer, const_cast<const void*>(shapeInfo), specialBuffer, const_cast<const void *>(specialShapeInfo));
}
void Context::setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo) {
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong const*>(shapeInfo));
if (_fastpath_in.size() < index + 1) if (_fastpath_in.size() < index + 1)
_fastpath_in.resize(index+1); _fastpath_in.resize(index+1);
@ -430,11 +434,15 @@ namespace sd {
_handles.emplace_back(array); _handles.emplace_back(array);
} }
void Context::setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { void Context::setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) {
this->setOutputArray(index, buffer, const_cast<const void *>(shapeInfo), specialBuffer, const_cast<const void *>(specialShapeInfo));
}
void Context::setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo) {
if (_fastpath_out.size() < index + 1) if (_fastpath_out.size() < index + 1)
_fastpath_out.resize(index+1); _fastpath_out.resize(index+1);
auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong *>(shapeInfo)); auto array = new NDArray(buffer, specialBuffer, reinterpret_cast<Nd4jLong const*>(shapeInfo));
_fastpath_out[index] = array; _fastpath_out[index] = array;
_handles.emplace_back(array); _handles.emplace_back(array);
@ -443,7 +451,7 @@ namespace sd {
array->setContext(_context); array->setContext(_context);
} }
void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) { void Context::setInputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer); auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_in.size() < index + 1) if (_fastpath_in.size() < index + 1)
@ -451,9 +459,9 @@ namespace sd {
NDArray *array; NDArray *array;
if (dataBuffer != nullptr) if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo)))); array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong const*>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong const*>(shapeInfo))));
else else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo)); array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong const*>(shapeInfo));
_fastpath_in[index] = array; _fastpath_in[index] = array;
_handles.emplace_back(array); _handles.emplace_back(array);
@ -462,7 +470,7 @@ namespace sd {
array->setContext(_context); array->setContext(_context);
} }
void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) { void Context::setOutputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer); auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_out.size() < index + 1) if (_fastpath_out.size() < index + 1)
@ -470,9 +478,9 @@ namespace sd {
NDArray *array; NDArray *array;
if (dataBuffer != nullptr) if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo)))); array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong const*>(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong const*>(shapeInfo))));
else else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo)); array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong const*>(shapeInfo));
_fastpath_out[index] = array; _fastpath_out[index] = array;
_handles.emplace_back(array); _handles.emplace_back(array);

View File

@ -50,8 +50,8 @@ namespace sd {
Nd4jLong result = 0L; Nd4jLong result = 0L;
Nd4jLong lastStep = 0L; Nd4jLong lastStep = 0L;
std::vector<Nd4jLong *> shapes; std::vector<Nd4jLong const*> shapes;
MAP_IMPL<std::pair<int,int>, Nd4jLong*> shapesMap; MAP_IMPL<std::pair<int,int>, Nd4jLong const*> shapesMap;
int cntFD = 0; int cntFD = 0;
@ -83,12 +83,12 @@ namespace sd {
auto in = node->input()->at(0); auto in = node->input()->at(0);
auto block = node->getContextPrototype(); auto block = node->getContextPrototype();
std::vector<Nd4jLong*> inputShapes; std::vector<Nd4jLong const*> inputShapes;
int *oldShape; int *oldShape;
for (auto v: *node->input()) { for (auto v: *node->input()) {
nd4j_debug(" inputs for estimation are: %i:%i\n", v.first, v.second); nd4j_debug(" inputs for estimation are: %i:%i\n", v.first, v.second);
if (v.first < 0) { if (v.first < 0) {
inputShapes.push_back(_variableSpace->getVariable(v.first)->getNDArray()->getShapeInfo()); inputShapes.push_back(_variableSpace->getVariable(v.first)->getNDArray()->shapeInfo());
} else { } else {
inputShapes.push_back(shapesMap.at(v)); inputShapes.push_back(shapesMap.at(v));
} }
@ -102,7 +102,7 @@ namespace sd {
int cnt = 0; int cnt = 0;
for (auto newShape: *outSha->asVector()) { for (auto newShape: *outSha->asVector()) {
std::pair<int, int> pairAddr(node->id(), cnt++); std::pair<int, int> pairAddr(node->id(), cnt++);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape); std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape); shapesMap.insert(pairShape);
@ -122,11 +122,11 @@ namespace sd {
auto x = _variableSpace->getVariable(in); auto x = _variableSpace->getVariable(in);
auto z = _variableSpace->getVariable(node->id()); auto z = _variableSpace->getVariable(node->id());
auto newShape = new Nd4jLong[shape::shapeInfoLength(x->getNDArray()->getShapeInfo())]; auto newShape = new Nd4jLong[shape::shapeInfoLength(x->getNDArray()->shapeInfo())];
memcpy(newShape, x->getNDArray()->getShapeInfo(), shape::shapeInfoByteLength(x->getNDArray()->getShapeInfo())); memcpy(newShape, x->getNDArray()->shapeInfo(), shape::shapeInfoByteLength(x->getNDArray()->shapeInfo()));
std::pair<int, int> pairAddr(node->id(), 0); std::pair<int, int> pairAddr(node->id(), 0);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape); std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape); shapesMap.insert(pairShape);
@ -141,7 +141,7 @@ namespace sd {
memcpy(newShape, prevShape, shape::shapeInfoByteLength(prevShape)); memcpy(newShape, prevShape, shape::shapeInfoByteLength(prevShape));
std::pair<int, int> pairAddr(node->id(), 0); std::pair<int, int> pairAddr(node->id(), 0);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape); std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape); shapesMap.insert(pairShape);
@ -152,30 +152,30 @@ namespace sd {
} }
} else if (node->getOpClass() == OpClass_REDUCTION) { } else if (node->getOpClass() == OpClass_REDUCTION) {
Nd4jLong *newShape = nullptr; Nd4jLong const* newShape = nullptr;
// if that's scalar output - we don't care about previous node // if that's scalar output - we don't care about previous node
if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == sd::DataTypeUtils::max<int>())) { if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == sd::DataTypeUtils::max<int>())) {
newShape = new Nd4jLong[8]; // auto aNewShape = new Nd4jLong[8];
//
newShape[0] = 2; // aNewShape[0] = 2;
newShape[1] = 1; // aNewShape[1] = 1;
newShape[2] = 1; // aNewShape[2] = 1;
newShape[3] = 1; // aNewShape[3] = 1;
newShape[4] = 1; // aNewShape[4] = 1;
newShape[5] = 8192; // set type as FLOAT32 by default // aNewShape[5] = 8192; // set type as FLOAT32 by default
newShape[6] = 1; // aNewShape[6] = 1;
newShape[7] = 99; // aNewShape[7] = 99;
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', {1,1});
} else { } else {
auto in = node->input()->at(0); auto in = node->input()->at(0);
Nd4jLong *oldShape = nullptr; Nd4jLong const* oldShape = nullptr;
// calculate tads here // calculate tads here
if (in.first < 0) { if (in.first < 0) {
auto x = _variableSpace->getVariable(in)->getNDArray(); auto x = _variableSpace->getVariable(in)->getNDArray();
oldShape = x->getShapeInfo(); oldShape = x->shapeInfo();
} else { } else {
oldShape = shapesMap.at(in); oldShape = shapesMap.at(in);
@ -188,7 +188,7 @@ namespace sd {
} }
std::pair<int, int> pairAddr(node->id(), 0); std::pair<int, int> pairAddr(node->id(), 0);
std::pair<std::pair<int, int>, Nd4jLong*> pairShape(pairAddr, newShape); std::pair<std::pair<int, int>, Nd4jLong const*> pairShape(pairAddr, newShape);
shapesMap.insert(pairShape); shapesMap.insert(pairShape);

View File

@ -88,8 +88,8 @@ namespace sd {
void setObjectsSize(Nd4jLong bytes); void setObjectsSize(Nd4jLong bytes);
void setTotalSize(Nd4jLong bytes); void setTotalSize(Nd4jLong bytes);
void addInputShape(Nd4jLong *shapeInfo); void addInputShape(Nd4jLong const* shapeInfo);
void addOutputShape(Nd4jLong *shapeInfo); void addOutputShape(Nd4jLong const* shapeInfo);
Nd4jLong getActivationsSize() const; Nd4jLong getActivationsSize() const;
Nd4jLong getTemporarySize() const; Nd4jLong getTemporarySize() const;

View File

@ -116,11 +116,11 @@ namespace sd {
return _executionTime; return _executionTime;
} }
void NodeProfile::addInputShape(Nd4jLong *shapeInfo) { void NodeProfile::addInputShape(Nd4jLong const* shapeInfo) {
_inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); _inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
} }
void NodeProfile::addOutputShape(Nd4jLong *shapeInfo) { void NodeProfile::addOutputShape(Nd4jLong const*shapeInfo) {
_outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); _outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
} }

View File

@ -51,20 +51,20 @@ namespace sd {
ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape); ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor); ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo); ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape); ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> dimensions = {}); ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
Nd4jLong* emptyShapeInfo(const sd::DataType dataType); const Nd4jLong* emptyShapeInfo(sd::DataType dataType);
Nd4jLong* scalarShapeInfo(const sd::DataType dataType); const Nd4jLong* scalarShapeInfo(sd::DataType dataType);
Nd4jLong* vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType); const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType);
Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); const Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor);
Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape); const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape); const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
Nd4jLong* createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo); const Nd4jLong* createShapeInfo(sd::DataType dataType, const Nd4jLong* shapeInfo);
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace); const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace);
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor); bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor);

View File

@ -41,43 +41,43 @@ namespace sd {
public: public:
template <typename OpType> template <typename OpType>
static FORCEINLINE void loopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop); static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop);
}; };
template <typename X, typename Z> template <typename X, typename Z>
class ReductionFloatLoops : public ReductionLoops<X, Z, Z> { class ReductionFloatLoops : public ReductionLoops<X, Z, Z> {
public: public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop); static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop); static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
}; };
template <typename X, typename Z> template <typename X, typename Z>
class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops<X, Z, X> { class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops<X, Z, X> {
public: public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
}; };
template <typename X, typename Z> template <typename X, typename Z>
class ND4J_EXPORT ReductionLongLoops : public ReductionLoops<X, Z, X> { class ND4J_EXPORT ReductionLongLoops : public ReductionLoops<X, Z, X> {
public: public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
}; };
template <typename X> template <typename X>
class ND4J_EXPORT ReductionSameLoops : public ReductionLoops<X, X, X> { class ND4J_EXPORT ReductionSameLoops : public ReductionLoops<X, X, X> {
public: public:
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static void innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
}; };
@ -85,10 +85,10 @@ namespace sd {
class ND4J_EXPORT IndexReductionLoops { class ND4J_EXPORT IndexReductionLoops {
private: private:
public: public:
static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams); static void wrapIndexReduce(int opNum, const void* x, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* extraParams);
template <typename OpType> template <typename OpType>
static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams); static void loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams);
}; };
@ -98,7 +98,7 @@ namespace sd {
public: public:
template<typename OpType> template<typename OpType>
static FORCEINLINE void loopTransform(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, E* extraParams, uint64_t threadId, uint64_t numThreads); static FORCEINLINE void loopTransform(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, E* extraParams, uint64_t threadId, uint64_t numThreads);
}; };
template <typename X, typename Z> template <typename X, typename Z>
@ -106,20 +106,20 @@ namespace sd {
public: public:
template <typename OpType> template <typename OpType>
static FORCEINLINE void loopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); static FORCEINLINE void loopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static FORCEINLINE void loopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); static FORCEINLINE void loopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
static void wrapperAll(const int opNum, X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); static void wrapperAll(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static void innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); static void innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop);
template <typename OpType> template <typename OpType>
static void innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
}; };
@ -263,10 +263,11 @@ namespace sd {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename E> template<typename X, typename Z, typename E>
template <typename OpType> template <typename OpType>
void sd::ReductionLoops<X, Z, E>::loopReduce(X* x, Nd4jLong* xShapeInfo, void sd::ReductionLoops<X, Z, E>::loopReduce(const X* x, const Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo, Z* z, const Nd4jLong* zShapeInfo,
Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
E* extraParams, int64_t start, int64_t stop) { E* extraParams,
int64_t start, int64_t stop) {
const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo);
@ -492,9 +493,10 @@ namespace sd {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename E> template <typename X, typename Z, typename E>
template <typename OpType> template <typename OpType>
void sd::TransformLoops<X, Z, E>::loopTransform(X* x, Nd4jLong* xShapeInfo, void sd::TransformLoops<X, Z, E>::loopTransform(const X* x, const Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo, Z* z, const Nd4jLong* zShapeInfo,
E* extraParams, uint64_t threadId, uint64_t numThreads) { E* extraParams,
uint64_t threadId, uint64_t numThreads) {
const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
@ -682,11 +684,11 @@ namespace sd {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void sd::Reduction3Loops<X, Z>::loopReduce3(X* x, Nd4jLong* xShapeInfo, void sd::Reduction3Loops<X, Z>::loopReduce3(const X* x, const Nd4jLong* xShapeInfo,
X* y, Nd4jLong* yShapeInfo, const X* y, const Nd4jLong* yShapeInfo,
Z* z, Nd4jLong* zShapeInfo, Z* z, const Nd4jLong* zShapeInfo,
int* dims, int dimsLen, int* dims, int dimsLen,
Z* extraParameters, int64_t start, int64_t stop) { Z* extraParameters, int64_t start, int64_t stop) {
// both tads have same shape, however strides and ews may differ // both tads have same shape, however strides and ews may differ
@ -695,7 +697,7 @@ namespace sd {
const Nd4jLong xLen = shape::length(xShapeInfo); const Nd4jLong xLen = shape::length(xShapeInfo);
const Nd4jLong yLen = shape::length(yShapeInfo); const Nd4jLong yLen = shape::length(yShapeInfo);
Nd4jLong* xTadShapeInfo = nullptr, * yTadShapeInfo = nullptr, * xTadOffsets = nullptr, * yTadOffsets = nullptr; const Nd4jLong* xTadShapeInfo = nullptr, * yTadShapeInfo = nullptr, * xTadOffsets = nullptr, * yTadOffsets = nullptr;
TadPack tadPackX, tadPackY; TadPack tadPackX, tadPackY;
std::vector<Nd4jLong> zeroOffsets; std::vector<Nd4jLong> zeroOffsets;
@ -962,12 +964,13 @@ namespace sd {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void sd::Reduction3Loops<X, Z>::loopReduce3All(X* x, Nd4jLong* xShapeInfo, void sd::Reduction3Loops<X, Z>::loopReduce3All(const X* x, const Nd4jLong* xShapeInfo,
X* y, Nd4jLong* yShapeInfo, const X* y, const Nd4jLong* yShapeInfo,
Z* z, Nd4jLong* zShapeInfo, Z* z, const Nd4jLong* zShapeInfo,
Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets,
Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets,
Z* extraParameters, int64_t start, int64_t stop) { Z* extraParameters,
int64_t start, int64_t stop) {
// both tads have same shape, however strides and ews may differ // both tads have same shape, however strides and ews may differ

View File

@ -35,28 +35,28 @@ namespace sd {
static std::vector<Nd4jLong> evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt); static std::vector<Nd4jLong> evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt);
// evaluate resulting shape after reduce operation // evaluate resulting shape after reduce operation
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
/** /**
* evaluate output shape for reduce operation when input shape is empty * evaluate output shape for reduce operation when input shape is empty
* behavior is analogous to tf * behavior is analogous to tf
*/ */
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace); static const Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace);
// evaluate shape for array which is result of repeat operation applied to arr // evaluate shape for array which is result of repeat operation applied to arr
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr); static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
// evaluate shapeInfo of permuted array // evaluate shapeInfo of permuted array
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false); static const Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false);
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace); static const Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace);
// evaluate shapeInfo of transposed array // evaluate shapeInfo of transposed array
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false); static const Nd4jLong* evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false);
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset); static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
@ -67,13 +67,13 @@ namespace sd {
// check whether 2 arrays have mutually broadcastable shapes // check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end // shape comparison starts from the end
static bool areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2); static bool areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2);
static bool areShapesBroadcastable(Nd4jLong* shapeX, Nd4jLong* shapeY); static bool areShapesBroadcastable(const Nd4jLong* shapeX, const Nd4jLong* shapeY);
static bool areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2); static bool areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2);
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array // check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false then array with larger rank has to be passed as first argument // if evalMinMax == false then array with larger rank has to be passed as first argument
static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace); static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace);
static bool evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace); static bool evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace);
// evaluate sorted vector of max axes to create tads along in case of simple broadcast operation // evaluate sorted vector of max axes to create tads along in case of simple broadcast operation
// if simple broadcast is not possible then empty vector is returned // if simple broadcast is not possible then empty vector is returned
@ -88,10 +88,10 @@ namespace sd {
static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min); static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min);
// evaluate shapeInfo for resulting array of tile operation // evaluate shapeInfo for resulting array of tile operation
static Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace); static const Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace);
// returns shape part of shapeInfo as std::vector // returns shape part of shapeInfo as std::vector
static std::vector<Nd4jLong> pullShapeFromShapeInfo(Nd4jLong *shapeInfo); static std::vector<Nd4jLong> pullShapeFromShapeInfo(const Nd4jLong *shapeInfo);
static std::string shapeAsString(const NDArray* array); static std::string shapeAsString(const NDArray* array);
static std::string shapeAsString(const std::vector<Nd4jLong>& shape); static std::string shapeAsString(const std::vector<Nd4jLong>& shape);
@ -104,13 +104,13 @@ namespace sd {
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo); static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal // evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
static Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, sd::memory::Workspace* workspace); static const Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, sd::memory::Workspace* workspace);
static std::vector<int> evalBroadcastBackwardAxis(const Nd4jLong *operand, const Nd4jLong *result); static std::vector<int> evalBroadcastBackwardAxis(const Nd4jLong *operand, const Nd4jLong *result);
// utility to calculate matrix product shape with give source shapes and additional params // utility to calculate matrix product shape with give source shapes and additional params
// returns ShapeList pointer with result shape // returns ShapeList pointer with result shape
static Nd4jLong* matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace); static const Nd4jLong* matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace);
/** /**
* This method evaluates permutation vector necessary for reducing of shapeFrom to shapeTo * This method evaluates permutation vector necessary for reducing of shapeFrom to shapeTo

View File

@ -55,20 +55,20 @@ namespace shape {
Nd4jLong tadIndex = 0; Nd4jLong tadIndex = 0;
int dimensionLength; int dimensionLength;
int* dimension = nullptr; int* dimension = nullptr;
Nd4jLong *shapeInfo = nullptr; Nd4jLong const* shapeInfo = nullptr;
Nd4jLong *tadOnlyShapeInfo = nullptr; Nd4jLong* tadOnlyShapeInfo = nullptr;
Nd4jLong numTads = 0; Nd4jLong numTads = 0;
int tadRank = 0; int tadRank = 0;
Nd4jLong *tadShape = nullptr; Nd4jLong* tadShape = nullptr;
Nd4jLong *tadStride = nullptr; Nd4jLong* tadStride = nullptr;
Nd4jLong *tadOffsets = nullptr; Nd4jLong* tadOffsets = nullptr;
Nd4jLong tadOffsetForBlock = 0; Nd4jLong tadOffsetForBlock = 0;
int rank = 0; int rank = 0;
int numOnes = 0; int numOnes = 0;
//pointers to original //pointers to original
int originalDimensionLength; int originalDimensionLength;
int *originalDimension = nullptr; int const* originalDimension = nullptr;
Nd4jLong *originalShapeInfo = nullptr; Nd4jLong const* originalShapeInfo = nullptr;
bool squeezed = false; bool squeezed = false;
bool newSqueezeDimensions = false; bool newSqueezeDimensions = false;
int numOnesInMiddle = 0; int numOnesInMiddle = 0;
@ -81,7 +81,7 @@ namespace shape {
void *ptrManager = nullptr; void *ptrManager = nullptr;
int *ptrOutput = nullptr; int *ptrOutput = nullptr;
INLINEDEF bool dimensionsDescending(int rank, int *dimensions, int length); INLINEDEF bool dimensionsDescending(int rank, int const* dimensions, int length);
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
@ -114,12 +114,12 @@ namespace shape {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF void init(Nd4jLong *shapeInfo,int *dimension,int dimensionLength); INLINEDEF void init(Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength);
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF void init(int index, Nd4jLong *shapeInfo,int *dimension,int dimensionLength); INLINEDEF void init(int index, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength);
@ -134,12 +134,12 @@ namespace shape {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out); INLINEDEF void permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong *out);
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearrange); INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange);
@ -153,7 +153,7 @@ namespace shape {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong *shapeBuffer); INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong const* shapeBuffer);
#ifdef __CUDACC__ #ifdef __CUDACC__
@ -253,7 +253,7 @@ namespace shape {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); INLINEDEF Nd4jLong tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength);
/** /**
* Computes the number * Computes the number
@ -263,7 +263,7 @@ namespace shape {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __device__ __host__ __device__
#endif #endif
INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength);
#ifdef __CUDACC__ #ifdef __CUDACC__
@ -337,19 +337,19 @@ namespace shape {
this->wholeThing = this->numTads == 1 || ((this->dimensionLength == this->rank || this->numTads == shape::length(this->shapeInfo)) && ews == 1); this->wholeThing = this->numTads == 1 || ((this->dimensionLength == this->rank || this->numTads == shape::length(this->shapeInfo)) && ews == 1);
} }
INLINEDEF void TAD::init(int tadIndex, Nd4jLong *shapeInfo,int *dimension,int dimensionLength) { INLINEDEF void TAD::init(int tadIndex, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength) {
this->tadIndex = tadIndex; this->tadIndex = tadIndex;
this->init(shapeInfo, dimension, dimensionLength); this->init(shapeInfo, dimension, dimensionLength);
} }
INLINEDEF void TAD::init(Nd4jLong *shapeInfo, int *dimension,int dimensionLength) { INLINEDEF void TAD::init(Nd4jLong const* shapeInfo, int const* dimension,int dimensionLength) {
this->originalShapeInfo = shapeInfo; this->originalShapeInfo = shapeInfo;
this->originalDimension = dimension; this->originalDimension = dimension;
this->originalDimensionLength = dimensionLength; this->originalDimensionLength = dimensionLength;
//start off as original references //start off as original references
this->shapeInfo = shapeInfo; this->shapeInfo = shapeInfo;
this->dimensionLength = dimensionLength; this->dimensionLength = dimensionLength;
this->dimension = dimension; this->dimension = const_cast<int*>(dimension);
this->rank = shape::rank(shapeInfo); this->rank = shape::rank(shapeInfo);
this->numTads = dimensionLength == 0 ? 1 : this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength); this->numTads = dimensionLength == 0 ? 1 : this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength);
@ -420,19 +420,19 @@ namespace shape {
} }
INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong* shapeBuffer, int* rearrange, Nd4jLong* out) { INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong* out) {
memcpy(out, shapeBuffer, sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank)); memcpy(out, shapeBuffer, sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank));
doPermuteShapeInfo(out, rearrange); doPermuteShapeInfo(out, rearrange);
} }
INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong* shapeBuffer, int *rearrange) { INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange) {
int len = shape::shapeInfoLength(this->rank); int len = shape::shapeInfoLength(this->rank);
Nd4jLong *copy = shape::copyOf(len,shapeBuffer); Nd4jLong *copy = shape::copyOf(len,shapeBuffer);
doPermuteShapeInfo(copy,rearrange); doPermuteShapeInfo(copy,rearrange);
return copy; return copy;
} }
INLINEDEF bool TAD::dimensionsDescending(int rank, int *dimensions, int length) { INLINEDEF bool TAD::dimensionsDescending(int rank, int const* dimensions, int length) {
int desired = rank - 1; int desired = rank - 1;
for (int e = length - 1; e >= 0; e--) { for (int e = length - 1; e >= 0; e--) {
if (dimensions[e] != desired--) if (dimensions[e] != desired--)
@ -465,7 +465,7 @@ namespace shape {
this->tadStride = shape::stride(this->tadOnlyShapeInfo); this->tadStride = shape::stride(this->tadOnlyShapeInfo);
} }
INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong* shapeBuffer) { INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const* shapeBuffer) {
int dimension = 0; int dimension = 0;
Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer),&dimension,shape::rank(shapeBuffer),1); Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer),&dimension,shape::rank(shapeBuffer),1);
Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1); Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1);
@ -635,7 +635,7 @@ namespace shape {
} }
INLINEDEF Nd4jLong* TAD::tensorShape() { INLINEDEF Nd4jLong* TAD::tensorShape(){
if(this->tadShape != nullptr) if(this->tadShape != nullptr)
return this->tadShape; return this->tadShape;
@ -902,7 +902,7 @@ namespace shape {
} }
INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) {
if(dimensionLength == 1) { if(dimensionLength == 1) {
return shape::shapeOf(shapeInfo)[dimension[0]]; return shape::shapeOf(shapeInfo)[dimension[0]];
} }
@ -919,7 +919,7 @@ namespace shape {
} }
INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) {
return shape::length(shapeInfo) / this->tadLength(shapeInfo,dimension,dimensionLength); return shape::length(shapeInfo) / this->tadLength(shapeInfo,dimension,dimensionLength);
} }

View File

@ -55,22 +55,16 @@ namespace sd {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = 0; int deviceId = 0;
_mutex.lock(); std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) { if (_cache[deviceId].count(descriptor) == 0) {
auto hPtr = descriptor.toShapeInfo(); auto hPtr = descriptor.toShapeInfo();
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor); ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer; _cache[deviceId][descriptor1] = buffer;
auto r = _cache[deviceId][descriptor1]; return _cache[deviceId][descriptor1];
_mutex.unlock();
return r;
} else { } else {
auto r = _cache[deviceId].at(descriptor); return _cache[deviceId].at(descriptor);
_mutex.unlock();
return r;
} }
} }
@ -82,52 +76,45 @@ namespace sd {
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
bool result; bool result;
int deviceId = 0; int deviceId = 0;
_mutex.lock(); std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) return _cache[deviceId].count(descriptor) != 0;
result = false;
else
result = true;
_mutex.unlock();
return result;
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank); ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) { const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo))); return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
} }
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) { const Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) { const Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); auto descriptor = ShapeDescriptor::scalarDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) { const Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) { const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape); ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) { const Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) { const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor); auto result = createShapeInfo(descriptor);
@ -137,7 +124,7 @@ namespace sd {
return result; return result;
} }
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) { const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor); auto result = createShapeInfo(descriptor);
@ -148,7 +135,7 @@ namespace sd {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> dimensions) { ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> &dimensions) {
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);

View File

@ -44,9 +44,9 @@ static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
const bool betaPersent = beta; const bool betaPersent = beta;
const Nd4jLong* aShapeInfo = vA->getShapeInfo(); const Nd4jLong* aShapeInfo = vA->shapeInfo();
const Nd4jLong* bShapeInfo = vB->getShapeInfo(); const Nd4jLong* bShapeInfo = vB->shapeInfo();
const Nd4jLong* cShapeInfo = vC->getShapeInfo(); const Nd4jLong* cShapeInfo = vC->shapeInfo();
const int aRank = vA->rankOf(); const int aRank = vA->rankOf();
const int bRank = vB->rankOf(); const int bRank = vB->rankOf();
@ -111,9 +111,9 @@ static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const
const bool betaPersent = beta; const bool betaPersent = beta;
const Nd4jLong* aShapeInfo = vA->getShapeInfo(); const Nd4jLong* aShapeInfo = vA->shapeInfo();
const Nd4jLong* xShapeInfo = vX->getShapeInfo(); const Nd4jLong* xShapeInfo = vX->shapeInfo();
const Nd4jLong* yShapeInfo = vY->getShapeInfo(); const Nd4jLong* yShapeInfo = vY->shapeInfo();
const int N = vX->lengthOf(); const int N = vX->lengthOf();
const int M = vY->lengthOf(); const int M = vY->lengthOf();
@ -294,13 +294,13 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
if(A->rankOf() != 2) if(A->rankOf() != 2)
throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !"); throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !");
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !");
const auto M = A->sizeAt(0); const auto M = A->sizeAt(0);
const auto N = A->sizeAt(1); const auto N = A->sizeAt(1);
if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !");
if(X->lengthOf() != N) if(X->lengthOf() != N)
throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !"); throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !");
@ -347,10 +347,10 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
// choose appropriate cuda gemm api depending on data types // choose appropriate cuda gemm api depending on data types
if(typeDouble) { if(typeDouble) {
BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->getBuffer(), lda, (double*)X->getBuffer(), incx, beta, (double*)Y->getBuffer(), incy); BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy);
} }
else if(typeFloat) { else if(typeFloat) {
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy); BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), incy);
} }
if(pA != A) if(pA != A)
@ -371,9 +371,9 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0); int xLenDim(0), yLenDim(0);
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot: X array must be vector !"); throw std::runtime_error("MmulHelper::dot: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) if(!shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); throw std::runtime_error("MmulHelper::dot: Y array must be vector !");
if(Z != nullptr && !Z->isScalar()) if(Z != nullptr && !Z->isScalar())
throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); throw std::runtime_error("MmulHelper::dot: Z array must be scalar !");
@ -393,8 +393,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
const auto yType = Y->dataType(); const auto yType = Y->dataType();
const auto zType = Z->dataType(); const auto zType = Z->dataType();
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES); BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), NUMERIC_TYPES);
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
return Z; return Z;
} }
@ -419,9 +419,9 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
const bool betaPersent = beta; const bool betaPersent = beta;
const Nd4jLong* aShapeInfo = vA->getShapeInfo(); const Nd4jLong* aShapeInfo = vA->shapeInfo();
const Nd4jLong* bShapeInfo = vB->getShapeInfo(); const Nd4jLong* bShapeInfo = vB->shapeInfo();
const Nd4jLong* cShapeInfo = vC->getShapeInfo(); const Nd4jLong* cShapeInfo = vC->shapeInfo();
const int aRank = vA->rankOf(); const int aRank = vA->rankOf();
const int bRank = vB->rankOf(); const int bRank = vB->rankOf();
@ -576,13 +576,13 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
// multiplication // multiplication
const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1});
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude);
std::vector<Nd4jLong> idxRanges(2 * C->rankOf()); std::vector<Nd4jLong> idxRanges(2 * C->rankOf());
// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) // #pragma omp parallel for schedule(guided) firstprivate(idxRanges)
for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data());
NDArray cSubArr = (*C)(idxRanges); NDArray cSubArr = (*C)(idxRanges);
if(aRank > bRank) { if(aRank > bRank) {

View File

@ -26,10 +26,10 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, void sd::IndexReductionLoops<X,Z>::loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo, Z* z, const Nd4jLong* zShapeInfo,
Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
X* extraParams) { X* extraParams) {
sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo);
if(kindOfLoop == sd::LoopKind::SMALLARR2DX) if(kindOfLoop == sd::LoopKind::SMALLARR2DX)
@ -305,8 +305,8 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
template <typename X, typename Y> template <typename X, typename Y>
void sd::IndexReductionLoops<X, Y>::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* vz, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) { void sd::IndexReductionLoops<X, Y>::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Y *>(vz); auto z = reinterpret_cast<Y *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong));

View File

@ -21,4 +21,4 @@
#include "./IndexReductionLoops.hpp" #include "./IndexReductionLoops.hpp"
BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong)); BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong));

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif #endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif #endif

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif #endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif #endif

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif #endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif #endif

View File

@ -28,7 +28,7 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif #endif
@ -36,21 +36,21 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X,Z>::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif #endif

View File

@ -26,17 +26,18 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionBoolLoops<X, Z>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { void ReductionBoolLoops<X, Z>::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionBoolLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionBoolLoops<X, Y>::wrapper(const int opNum,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, const X *x, const Nd4jLong *xShapeInfo,
Nd4jLong *tadOffsets, Y *z, const Nd4jLong *zShapeInfo,
X *extraParams, int64_t start, int64_t stop) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
X *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS);
#endif #endif

View File

@ -28,16 +28,18 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Y *z, const Nd4jLong *zShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Y *extraParams,
int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif #endif

View File

@ -28,16 +28,19 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, const X *x, const Nd4jLong *xShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { Y *z, const Nd4jLong *zShapeInfo,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Y *extraParams,
int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif #endif

View File

@ -28,16 +28,16 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif #endif

View File

@ -28,16 +28,16 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif #endif

View File

@ -33,16 +33,16 @@ namespace sd {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionLongLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z *z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { void ReductionLongLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionLongLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionLongLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) { const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS);
#endif #endif

View File

@ -26,16 +26,16 @@ namespace sd {
template<typename X> template<typename X>
template <typename OpType> template <typename OpType>
void ReductionSameLoops<X>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { void ReductionSameLoops<X>::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
template<typename X> template<typename X>
void ReductionSameLoops<X>::wrapper(const int opNum, X *vx, Nd4jLong *xShapeInfo, X *vz, void ReductionSameLoops<X>::wrapper(const int opNum, const X *vx, const Nd4jLong *xShapeInfo, X *vz,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, const Nd4jLong *tadOffsets,
X *vextraParams, int64_t start, int64_t stop) { X *vextraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS #ifndef INLINE_LOOPS
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);

View File

@ -83,40 +83,40 @@ namespace sd {
return _cache[deviceId].count(descriptor) != 0; return _cache[deviceId].count(descriptor) != 0;
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank); ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) { Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo))); return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
} }
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) { Nd4jLong const* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) { Nd4jLong const* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); auto descriptor = ShapeDescriptor::scalarDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) { Nd4jLong const* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) { Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape); ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) { Nd4jLong const* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) { Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor); auto result = createShapeInfo(descriptor);
@ -126,7 +126,7 @@ namespace sd {
return result; return result;
} }
Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) { Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
auto result = createShapeInfo(descriptor); auto result = createShapeInfo(descriptor);
@ -136,7 +136,7 @@ namespace sd {
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> dimensions) { ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int>& dimensions) {
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);

View File

@ -268,8 +268,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
const int sharedMem = threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank const int sharedMem = threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank
NDArray::prepareSpecialUse({C}, {A, B}); NDArray::prepareSpecialUse({C}, {A, B});
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES) BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({C}, {A, B}); NDArray::registerSpecialUse({C}, {A, B});
auto cudaResult = cudaStreamSynchronize(*stream); auto cudaResult = cudaStreamSynchronize(*stream);
@ -319,23 +319,23 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
// choose appropriate cuda gemm api depending on data types // choose appropriate cuda gemm api depending on data types
if(typeDouble) { if(typeDouble) {
status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)pB->getSpecialBuffer(), ldb, &beta, (double*)pC->getSpecialBuffer(), ldc); status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->specialBuffer(), lda, (double*)pB->specialBuffer(), ldb, &beta, (double*)pC->specialBuffer(), ldc);
} }
else if(typeFloat) { else if(typeFloat) {
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->specialBuffer(), lda, (float*)pB->specialBuffer(), ldb, &betaF, (float*)pC->specialBuffer(), ldc);
} }
else if(typeHalf) { else if(typeHalf) {
float16 alphaH(alpha), betaH(beta); float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->specialBuffer(), lda, (__half*)pB->specialBuffer(), ldb, &betaH.data, (__half*)pC->specialBuffer(), ldc);
} }
else if(typeIntFloat) { else if(typeIntFloat) {
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_8I, lda, pB->specialBuffer(), CUDA_R_8I, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc);
} }
else if(typeHalfFloat) { else if(typeHalfFloat) {
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_16F, lda, pB->specialBuffer(), CUDA_R_16F, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc);
} }
if (status != CUBLAS_STATUS_SUCCESS) if (status != CUBLAS_STATUS_SUCCESS)
@ -365,13 +365,13 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
if(A->rankOf() != 2) if(A->rankOf() != 2)
throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !");
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !");
const auto M = A->sizeAt(0); const auto M = A->sizeAt(0);
const auto N = A->sizeAt(1); const auto N = A->sizeAt(1);
if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !");
if(X->lengthOf() != N) if(X->lengthOf() != N)
throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !"); throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !");
@ -411,8 +411,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({Y}, {A, X}); NDArray::prepareSpecialUse({Y}, {A, X});
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES) BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({Y}, {A, X}); NDArray::registerSpecialUse({Y}, {A, X});
auto cudaResult = cudaStreamSynchronize(*stream); auto cudaResult = cudaStreamSynchronize(*stream);
@ -442,11 +442,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
// choose appropriate cuda gemm api depending on data types // choose appropriate cuda gemm api depending on data types
if(typeDouble) { if(typeDouble) {
status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)X->getSpecialBuffer(), incx, &beta, (double*)Y->getSpecialBuffer(), incy); status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->specialBuffer(), lda, (double*)X->specialBuffer(), incx, &beta, (double*)Y->specialBuffer(), incy);
} }
else if(typeFloat) { else if(typeFloat) {
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)X->getSpecialBuffer(), incx, &betaF, (float*)Y->getSpecialBuffer(), incy); status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alphaF, (float*)pA->specialBuffer(), lda, (float*)X->specialBuffer(), incx, &betaF, (float*)Y->specialBuffer(), incy);
} }
if (status != CUBLAS_STATUS_SUCCESS) if (status != CUBLAS_STATUS_SUCCESS)
@ -471,9 +471,9 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0); int xLenDim(0), yLenDim(0);
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) if(!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) if(!shape::isCommonVector(Y->shapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !");
if(Z != nullptr && !Z->isScalar()) if(Z != nullptr && !Z->isScalar())
throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !");
@ -506,8 +506,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
NDArray::prepareSpecialUse({Z}, {X, Y}); NDArray::prepareSpecialUse({Z}, {X, Y});
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES)
auto cudaResult = cudaStreamSynchronize(*stream); auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@ -667,8 +667,8 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
cBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int))); cBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int)));
NDArray::prepareSpecialUse({C}, {A, B}); NDArray::prepareSpecialUse({C}, {A, B});
// BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); // BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES) BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({C}, {A, B}); NDArray::registerSpecialUse({C}, {A, B});
manager.synchronize(); manager.synchronize();
@ -797,13 +797,13 @@ NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C,
// multiplication // multiplication
const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1});
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude);
std::vector<Nd4jLong> idxRanges(2 * C->rankOf()); std::vector<Nd4jLong> idxRanges(2 * C->rankOf());
// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) // #pragma omp parallel for schedule(guided) firstprivate(idxRanges)
for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data());
NDArray cSubArr = (*C)(idxRanges); NDArray cSubArr = (*C)(idxRanges);
if(aRank > bRank) { if(aRank > bRank) {
@ -944,18 +944,18 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
std::vector<void*> aSubArrs(bS), bSubArrs(bS), cSubArrs(bS); std::vector<void*> aSubArrs(bS), bSubArrs(bS), cSubArrs(bS);
if(aRank > 2) if(aRank > 2)
shape::calcSubArrsShapeInfoAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); shape::calcSubArrsShapeInfoAndOffsets(pA->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i) for (int i = 0; i < bS; ++i)
aSubArrs[i] = aRank == 2 ? pA->getSpecialBuffer() : pA->getSpecialBuffer() + subArrOffsets[i] * pA->sizeOfT(); aSubArrs[i] = aRank == 2 ? pA->specialBuffer() : pA->specialBuffer() + subArrOffsets[i] * pA->sizeOfT();
if(bRank > 2) if(bRank > 2)
shape::calcSubArrsShapeInfoAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); shape::calcSubArrsShapeInfoAndOffsets(pB->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i) for (int i = 0; i < bS; ++i)
bSubArrs[i] = bRank == 2 ? pB->getSpecialBuffer() : pB->getSpecialBuffer() + subArrOffsets[i] * pB->sizeOfT(); bSubArrs[i] = bRank == 2 ? pB->specialBuffer() : pB->specialBuffer() + subArrOffsets[i] * pB->sizeOfT();
shape::calcSubArrsShapeInfoAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); shape::calcSubArrsShapeInfoAndOffsets(pC->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i) for (int i = 0; i < bS; ++i)
cSubArrs[i] = pC->getSpecialBuffer() + subArrOffsets[i] * pC->sizeOfT(); cSubArrs[i] = pC->specialBuffer() + subArrOffsets[i] * pC->sizeOfT();
PointersManager manager(A->getContext(), "mmulNxN"); PointersManager manager(A->getContext(), "mmulNxN");
@ -1011,7 +1011,7 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
for(Nd4jLong i = 0; i < bS; ++i) { for(Nd4jLong i = 0; i < bS; ++i) {
ShapeUtils::evalIdxRangesForSubArr(i, pC->getShapeInfo(), dimsToExclude, idxRanges.data()); ShapeUtils::evalIdxRangesForSubArr(i, pC->shapeInfo(), dimsToExclude, idxRanges.data());
NDArray cSubArr = (*pC)(idxRanges); NDArray cSubArr = (*pC)(idxRanges);
if(aRank > bRank) { if(aRank > bRank) {

View File

@ -91,7 +91,7 @@ void sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::N
mmul(aPR, bPR, cPR, 1.0, 0.0); mmul(aPR, bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer() if(cPR->buffer() != cP->buffer() || cPR->specialBuffer() != cP->specialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->buffer()
cP->assign(cPR); cP->assign(cPR);
if(aP != aPR) if(aP != aPR)
@ -150,7 +150,7 @@ void sd::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, c
// check whether new buffer allocation was happened for c array // check whether new buffer allocation was happened for c array
if(!whatToDoWithC.empty()) { if(!whatToDoWithC.empty()) {
for(int i = cArrs.size()-1; i > 0; --i) { for(int i = cArrs.size()-1; i > 0; --i) {
if(cArrs[i]->getBuffer() != cArrs[i-1]->getBuffer() || cArrs[i]->getSpecialBuffer() != cArrs[i-1]->getSpecialBuffer()) if(cArrs[i]->buffer() != cArrs[i-1]->buffer() || cArrs[i]->specialBuffer() != cArrs[i-1]->specialBuffer())
cArrs[i-1]->assign(cArrs[i]); cArrs[i-1]->assign(cArrs[i]);
delete cArrs[i]; delete cArrs[i];
} }
@ -203,8 +203,8 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
int lenDim; int lenDim;
const int aRank = A->rankOf(); const int aRank = A->rankOf();
const int bRank = B->rankOf(); const int bRank = B->rankOf();
const bool isAVector = shape::isCommonVector(A->getShapeInfo(), lenDim); const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim);
const bool isBVector = shape::isCommonVector(B->getShapeInfo(), lenDim); const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim);
// dot product of 2 vectors // dot product of 2 vectors
if(isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4) if(isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4)
@ -243,7 +243,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
int xRank = x->rankOf(); int xRank = x->rankOf();
int yRank = y->rankOf(); int yRank = y->rankOf();
auto outShape = ShapeUtils::evalShapeForMatmul(x->getShapeInfo(), y->getShapeInfo(), transX, transY); auto outShape = ShapeUtils::evalShapeForMatmul(x->shapeInfo(), y->shapeInfo(), transX, transY);
if(!z->isSameShape(outShape)) { if(!z->isSameShape(outShape)) {
nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
throw std::invalid_argument(""); throw std::invalid_argument("");
@ -285,7 +285,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
for(int i = 0; i < batchRank; ++i) for(int i = 0; i < batchRank; ++i)
dimsToExclude[i] = i; dimsToExclude[i] = i;
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->getShapeInfo(), dimsToExclude); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->shapeInfo(), dimsToExclude);
//PRAGMA_OMP_PARALLEL_FOR //PRAGMA_OMP_PARALLEL_FOR
for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {

View File

@ -118,13 +118,13 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt) { std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt) {
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt); return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate output shape for reduce operation when input shape is empty // evaluate output shape for reduce operation when input shape is empty
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) {
if (dimsToExclude.size() == 0) { // return copy of input shape if (dimsToExclude.size() == 0) { // return copy of input shape
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace); Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
@ -171,22 +171,22 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace); return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
} }
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace); return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, keepDims, supportOldShapes, workspace);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation // evaluate shape resulting from reduce operation
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace); return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);
@ -314,39 +314,39 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array // evaluate shapeInfo of permuted array
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) { const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {
if (!arr.nonNull()) if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!"); throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (rank != arr.rankOf()) if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!"); throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
auto shapeInfoLength = shape::shapeInfoLength(rank); auto shapeInfoLength = shape::shapeInfoLength(rank);
// allocate memory for new array - shapeInfo // allocate memory for new array - shapeInfo
Nd4jLong *shapeInfoNew = nullptr; Nd4jLong *shapeInfoNew = nullptr;
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong); ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
// copy arr _shapeInfo into new array // copy arr _shapeInfo into new array
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank)); memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank));
// perform buffer permutation // perform buffer permutation
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf()); shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
if(setContigStrides) if(setContigStrides)
shape::updateStrides(shapeInfoNew, arr.ordering()); shape::updateStrides(shapeInfoNew, arr.ordering());
ShapeDescriptor descriptor(shapeInfoNew); ShapeDescriptor descriptor(shapeInfoNew);
RELEASE(shapeInfoNew, workspace); RELEASE(shapeInfoNew, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array // evaluate shapeInfo of permuted array
Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) {
std::vector<int> dims(dimensions, dimensions + rank); std::vector<int> dims(dimensions, dimensions + rank);
return evalPermShapeInfo(dims.data(), rank, arr, workspace); return evalPermShapeInfo(dims.data(), rank, arr, workspace);
@ -354,7 +354,7 @@ Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, c
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of transposed array // evaluate shapeInfo of transposed array
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) { const Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {
int rank = arr.rankOf(); int rank = arr.rankOf();
std::vector<int> dimensions(rank); std::vector<int> dimensions(rank);
@ -414,10 +414,10 @@ std::vector<int> ShapeUtils::evalDimsToExclude(const int rank, const std::vector
// check whether 2 arrays have mutually broadcastable shapes // check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end // shape comparison starts from the end
bool ShapeUtils::areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2) { bool ShapeUtils::areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2) {
return areShapesBroadcastable(arr1.getShapeInfo(), arr2.getShapeInfo()); return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo());
} }
bool ShapeUtils::areShapesBroadcastable(Nd4jLong *shapeInfo1, Nd4jLong *shapeInfo2) { bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) {
int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) ? shape::rank(shapeInfo1) : shape::rank(shapeInfo2); int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) ? shape::rank(shapeInfo1) : shape::rank(shapeInfo2);
for (int i = -1; i >= -minRank; --i) for (int i = -1; i >= -minRank; --i)
@ -427,177 +427,177 @@ bool ShapeUtils::areShapesBroadcastable(Nd4jLong *shapeInfo1, Nd4jLong *shapeInf
return true; return true;
} }
bool ShapeUtils::areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2) { bool ShapeUtils::areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2) {
const auto rank1 = shape1.size(); const auto rank1 = shape1.size();
const auto rank2 = shape2.size(); const auto rank2 = shape2.size();
const int minRank = rank1 < rank2 ? rank1 : rank2; const int minRank = rank1 < rank2 ? rank1 : rank2;
for (int i = 1; i <= minRank; ++i) for (int i = 1; i <= minRank; ++i)
if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1) if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1)
return false;
return true;
}
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false the array with larger rank has to be passed as first argument
bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, resultShapeInfo, workspace);
}
bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
// check whether broadcast operation is possible for input arrays
if(!areShapesBroadcastable(max, min))
return false; return false;
return true; auto maxShapeInfo = max; //max.shapeInfo();
} auto minShapeInfo = min; //min.shapeInfo();
////////////////////////////////////////////////////////////////////////// if(evalMinMax && (shape::rank(max) < shape::rank(min))) {
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array maxShapeInfo = min;
// if evalMinMax == false the array with larger rank has to be passed as first argument minShapeInfo = max;
bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) { }
return evalBroadcastShapeInfo(max.getShapeInfo(), min.getShapeInfo(), evalMinMax, resultShapeInfo, workspace);
}
bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) { const auto maxRank = shape::rank(maxShapeInfo);
const auto minRank = shape::rank(minShapeInfo);
// check whether broadcast operation is possible for input arrays // evaluate shapeInfo for resulting array
if(!areShapesBroadcastable(max, min)) if(resultShapeInfo != nullptr)
return false; throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
auto maxShapeInfo = max; //max.getShapeInfo(); Nd4jLong *tmpShapeInfo = nullptr;
auto minShapeInfo = min; //min.getShapeInfo(); ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
if(evalMinMax && (shape::rank(max) < shape::rank(min))) { // FIXME: get rid of memcpy here
maxShapeInfo = min; memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
minShapeInfo = max; for (int i = 0; i < minRank; ++i)
if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
if (shape::isEmpty(max) || shape::isEmpty(min)) {
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
}
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
return true;
} }
const auto maxRank = shape::rank(maxShapeInfo); //////////////////////////////////////////////////////////////////////////
const auto minRank = shape::rank(minShapeInfo); // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) {
// evaluate shapeInfo for resulting array if(resultShapeInfo != nullptr)
if(resultShapeInfo != nullptr) throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");
Nd4jLong *tmpShapeInfo = nullptr; int size = arrays.size();
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); int maxRank = arrays[size - 1]->rankOf();
// FIXME: get rid of memcpy here for(int i = 0; i < size - 1; ++i) {
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank)); if(arrays[i]->rankOf() > maxRank)
for (int i = 0; i < minRank; ++i) maxRank = arrays[i]->rankOf();
if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0) for(int j = i + 1; j < size; ++j)
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i]; if(!areShapesBroadcastable(*arrays[i], *arrays[j]))
return false;
}
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo)); Nd4jLong *tmpShapeInfo = nullptr;
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank));
tmpShapeInfo[0] = maxRank;
if (shape::isEmpty(max) || shape::isEmpty(min)) { for(const auto& item : arrays ) {
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY); for(int i = -1; i >= -item->rankOf(); --i)
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong)); if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i))
tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i);
}
shape::updateStrides(tmpShapeInfo, arrays[0]->ordering());
ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType());
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = const_cast<Nd4jLong*>(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
return true;
} }
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
return true; //////////////////////////////////////////////////////////////////////////
} // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {
////////////////////////////////////////////////////////////////////////// const NDArray *min, *max;
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) {
if(resultShapeInfo != nullptr) if(arr1.rankOf() >= arr2.rankOf()) {
throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !"); max = &arr1;
min = &arr2;
}
else {
max = &arr2;
min = &arr1;
}
int size = arrays.size(); const int rankDiff = max->rankOf() - min->rankOf();
int maxRank = arrays[size - 1]->rankOf();
for(int i = 0; i < size - 1; ++i) { std::vector<int> dims;
if(arrays[i]->rankOf() > maxRank)
maxRank = arrays[i]->rankOf(); for (int i = 0; i < min->rankOf(); ++i)
for(int j = i + 1; j < size; ++j) if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
if(!areShapesBroadcastable(*arrays[i], *arrays[j])) dims.emplace_back(rankDiff + i);
return false;
return dims;
} }
Nd4jLong *tmpShapeInfo = nullptr; //////////////////////////////////////////////////////////////////////////
ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); // evaluate shapeInfo for resulting array from tile operation
memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank)); const Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace) {
tmpShapeInfo[0] = maxRank; // check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
int repsSize = reps.size();
Nd4jLong product = 1;
for(const auto& item : reps)
product *= item;
if(product == 0)
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
for(const auto& item : arrays ) { int rankOld = arr.rankOf();
for(int i = -1; i >= -item->rankOf(); --i) int diff = rankOld - repsSize;
if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i))
tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i); // evaluate new shapeInfo
Nd4jLong* newShapeInfo = nullptr;
if(diff < 0) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
newShapeInfo[0] = repsSize; // set new rank
for(int i=1; i <= -diff; ++i)
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
memcpy(newShapeInfo + 1 - diff, arr.shapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
for(int i=1; i <= repsSize; ++i)
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
else {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
memcpy(newShapeInfo, arr.shapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
for(int i=1; i <= repsSize; ++i)
newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
shape::updateStrides(newShapeInfo, arr.ordering());
ArrayOptions::setDataType(newShapeInfo, arr.dataType());
ShapeDescriptor descriptor(newShapeInfo);
RELEASE(newShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
shape::updateStrides(tmpShapeInfo, arrays[0]->ordering()); std::vector<Nd4jLong> ShapeUtils::pullShapeFromShapeInfo(const Nd4jLong *shapeInfo) {
ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType());
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
return true;
}
//////////////////////////////////////////////////////////////////////////
// return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {
const NDArray *min, *max;
if(arr1.rankOf() >= arr2.rankOf()) {
max = &arr1;
min = &arr2;
}
else {
max = &arr2;
min = &arr1;
}
const int rankDiff = max->rankOf() - min->rankOf();
std::vector<int> dims;
for (int i = 0; i < min->rankOf(); ++i)
if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
dims.emplace_back(rankDiff + i);
return dims;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for resulting array from tile operation
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace) {
// check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
int repsSize = reps.size();
Nd4jLong product = 1;
for(const auto& item : reps)
product *= item;
if(product == 0)
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
int rankOld = arr.rankOf();
int diff = rankOld - repsSize;
// evaluate new shapeInfo
Nd4jLong* newShapeInfo = nullptr;
if(diff < 0) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
newShapeInfo[0] = repsSize; // set new rank
for(int i=1; i <= -diff; ++i)
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
for(int i=1; i <= repsSize; ++i)
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
else {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
for(int i=1; i <= repsSize; ++i)
newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
}
shape::updateStrides(newShapeInfo, arr.ordering());
ArrayOptions::setDataType(newShapeInfo, arr.dataType());
ShapeDescriptor descriptor(newShapeInfo);
RELEASE(newShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
std::vector<Nd4jLong> ShapeUtils::pullShapeFromShapeInfo(Nd4jLong *shapeInfo) {
std::vector<Nd4jLong> shape(shape::rank(shapeInfo)); std::vector<Nd4jLong> shape(shape::rank(shapeInfo));
int shapeSize = shape.size(); int shapeSize = shape.size();
@ -624,7 +624,7 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
std::string ShapeUtils::strideAsString(const NDArray* array) { std::string ShapeUtils::strideAsString(const NDArray* array) {
std::string result; std::string result;
auto shapeBuffer = array->getShapeInfo(); //Nd4jLong* auto shapeBuffer = array->shapeInfo(); //Nd4jLong*
int rank = (int)*shapeBuffer; int rank = (int)*shapeBuffer;
result.append("["); result.append("[");
for (int e = 0; e < rank; e++) { for (int e = 0; e < rank; e++) {
@ -724,31 +724,31 @@ std::vector<Nd4jLong> ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal // evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){ const Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){
auto shapeInfo = const_cast<Nd4jLong*>(shapeInfoConst); auto shapeInfo = const_cast<Nd4jLong*>(shapeInfoConst);
const auto rank = shape::rank(shapeInfo); const auto rank = shape::rank(shapeInfo);
Nd4jLong* outputShapeInfo = nullptr; Nd4jLong* outputShapeInfo = nullptr;
if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) { if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
outputShapeInfo[0] = 2; outputShapeInfo[0] = 2;
outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo); outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo);
}
else {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong);
outputShapeInfo[0] = 2*rank;
for(int i = 1; i <= rank; ++i)
outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i];
}
ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo);
RELEASE(outputShapeInfo, workspace);
return result;
} }
else {
ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong);
outputShapeInfo[0] = 2*rank;
for(int i = 1; i <= rank; ++i)
outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i];
}
ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo);
RELEASE(outputShapeInfo, workspace);
return result;
}
std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandShapeInfo, const Nd4jLong *resultShapeInfo) { std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandShapeInfo, const Nd4jLong *resultShapeInfo) {
// rRank >= oRank always !! // rRank >= oRank always !!
@ -765,83 +765,82 @@ std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandSh
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeUtils::matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace) { const Nd4jLong* ShapeUtils::matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace) {
auto inA = theFirstShape;
auto inB = theSecondShape;
Nd4jLong *shape;
ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong);
auto inA = theFirstShape; Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace);
auto inB = theSecondShape; Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace);
Nd4jLong *shape;
ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong);
Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace); if (shouldTranspondFirst)
Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace); shape::transposeInplace(tmpA);
if (shouldTranspondFirst) if (shouldTranspondSecond)
shape::transposeInplace(tmpA); shape::transposeInplace(tmpB);
if (shouldTranspondSecond)
shape::transposeInplace(tmpB);
if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) { if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) {
// special case here // special case here
shape[0] = 1; shape[0] = 1;
shape[1] = tmpB[2];
Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
} else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
// just scalar vs scalar
shape[0] = 1;
shape[1] = 1;
} else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
// gemv case
if (shape::rank(tmpB) == 2) {
shape[0] = tmpA[1];
shape[1] = tmpB[2]; shape[1] = tmpB[2];
} else { Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);
// we have new 1D shape here
auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace);
RELEASE(shape, workspace); RELEASE(shape, workspace);
RELEASE(tmpA, workspace); RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace); RELEASE(tmpB, workspace);
return newShape; return newShape;
} else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
// just scalar vs scalar
shape[0] = 1;
shape[1] = 1;
} else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
// gemv case
if (shape::rank(tmpB) == 2) {
shape[0] = tmpA[1];
shape[1] = tmpB[2];
} else {
// we have new 1D shape here
auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
}
} else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isVector(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
// gemm case
shape[0] = tmpA[1];
shape[1] = tmpB[2];
} else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) ||
(shape::isScalar(tmpA) && shape::isVector(tmpB))) {
// element-wise
shape[0] = 1;
shape[1] = (int) sd::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
} else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
} else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
} }
} else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
(shape::isVector(tmpA) && shape::isMatrix(tmpB)) || auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'f', 2, shape);
(shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
// gemm case RELEASE(shape, workspace);
shape[0] = tmpA[1];
shape[1] = tmpB[2]; RELEASE(tmpA, workspace);
} else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) || RELEASE(tmpB, workspace);
(shape::isScalar(tmpA) && shape::isVector(tmpB))) { return newShape;
// element-wise
shape[0] = 1;
shape[1] = (int) sd::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
} else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
} else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) {
// dot case
shape[0] = 1;
shape[1] = 1;
} }
Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);
RELEASE(shape, workspace);
RELEASE(tmpA, workspace);
RELEASE(tmpB, workspace);
return newShape;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
std::vector<int> ShapeUtils::evalPermutFromTo(const std::vector<Nd4jLong>& shapeFrom, const std::vector<Nd4jLong>& shapeTo) { std::vector<int> ShapeUtils::evalPermutFromTo(const std::vector<Nd4jLong>& shapeFrom, const std::vector<Nd4jLong>& shapeTo) {
auto rank = shapeFrom.size(); auto rank = shapeFrom.size();

View File

@ -65,7 +65,7 @@ namespace shape {
* the information on an ndarray * the information on an ndarray
*/ */
struct ND4J_EXPORT ShapeInformation { struct ND4J_EXPORT ShapeInformation {
_CUDA_HD ShapeInformation(Nd4jLong *shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0) _CUDA_HD ShapeInformation(Nd4jLong* shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0)
: shape(shape_), stride(stride_), order(order_), rank(rank_), offset(offset_), elementWiseStride(elementWiseStride_) : shape(shape_), stride(stride_), order(order_), rank(rank_), offset(offset_), elementWiseStride(elementWiseStride_)
{} {}
@ -93,19 +93,19 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2); ND4J_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2);
ND4J_EXPORT _CUDA_HD Nd4jLong* detachShape(Nd4jLong *originalShape); ND4J_EXPORT _CUDA_HD const Nd4jLong* detachShape(const Nd4jLong *originalShape);
ND4J_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong *originalShape); ND4J_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong const* originalShape);
ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2);
ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3);
ND4J_EXPORT _CUDA_HD bool strideEquals(int shape1Rank,Nd4jLong *shape1,int shape2Rank,Nd4jLong *shape2); ND4J_EXPORT _CUDA_HD bool strideEquals(int const shape1Rank,Nd4jLong const* shape1,int const shape2Rank, Nd4jLong const* shape2);
ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *shapeInfo1,Nd4jLong *shapeInfo2); ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1, Nd4jLong const* shapeInfo2);
ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *stride1,int rank1,Nd4jLong *stride2,int rank2); ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1, Nd4jLong const* stride2, int const rank2);
ND4J_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); ND4J_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB);
@ -128,7 +128,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); ND4J_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength);
ND4J_EXPORT _CUDA_HD Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); ND4J_EXPORT _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder); ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder);
@ -142,17 +142,17 @@ namespace shape {
* Get the shape info buffer * Get the shape info buffer
* for the given rank and shape. * for the given rank and shape.
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape); ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *buffer); ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer);
/** /**
* Get the shape info buffer * Get the shape info buffer
* for the given rank and shape. * for the given rank and shape.
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape); ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *output); ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *output);
#ifdef __CUDACC__ #ifdef __CUDACC__
@ -168,9 +168,9 @@ namespace shape {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank); ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, Nd4jLong* ret); ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret);
/** /**
* Computes the standard packed array strides for a given shape. * Computes the standard packed array strides for a given shape.
@ -180,9 +180,9 @@ namespace shape {
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank); ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, Nd4jLong* ret); ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret);
ND4J_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); ND4J_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order);
ND4J_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order); ND4J_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order);
@ -199,9 +199,9 @@ namespace shape {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum); ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret); ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret);
/** /**
* Computes the standard packed array strides for a given shape. * Computes the standard packed array strides for a given shape.
@ -210,9 +210,9 @@ namespace shape {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum); ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const* shape, int rank, int startNum);
ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret); ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret);
/** /**
* @param toCopy the shape to copy * @param toCopy the shape to copy
@ -244,7 +244,7 @@ namespace shape {
* @return 0 if there is no element wise stride the * @return 0 if there is no element wise stride the
* element wise stride of reshape(1,length) otherwise * element wise stride of reshape(1,length) otherwise
*/ */
ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder); ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder);
/** /**
* Compute the element wise stride * Compute the element wise stride
@ -257,11 +257,11 @@ namespace shape {
* @return 0 if there is no element wise stride the * @return 0 if there is no element wise stride the
* element wise stride of reshape(1,length) otherwise * element wise stride of reshape(1,length) otherwise
*/ */
ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder, Nd4jLong *dimension, int dimensionLength); ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder, Nd4jLong const* dimension, int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride); ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong const* shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer); ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer);
/** /**
* *
* @param length * @param length
@ -281,7 +281,7 @@ namespace shape {
*/ */
ND4J_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange); ND4J_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange);
ND4J_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong *shapeBuffer, int* rearrange); ND4J_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange);
ND4J_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out); ND4J_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out);
@ -304,7 +304,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength); ND4J_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength);
ND4J_EXPORT _CUDA_HD Nd4jLong* computeResultShape(Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength); ND4J_EXPORT _CUDA_HD Nd4jLong* computeResultShape(const Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength);
/** /**
* This method does inplace transpose of given shapeBuffer * This method does inplace transpose of given shapeBuffer
@ -350,7 +350,7 @@ namespace shape {
* @param shape the shape of the array * @param shape the shape of the array
* @param rank the rank of cthe shape * @param rank the rank of cthe shape
*/ */
ND4J_EXPORT _CUDA_HD int isVector(Nd4jLong *shape, int rank); ND4J_EXPORT _CUDA_HD int isVector(Nd4jLong const* shape, int rank);
/** /**
@ -363,13 +363,13 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); ND4J_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong *shapeInfo, int& posOfNonUnityDim); ND4J_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim);
ND4J_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim); ND4J_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim);
ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo); ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong const* shapeInfo);
/** /**
* shape - input inShape is shape only, not shapeInfo * shape - input inShape is shape only, not shapeInfo
@ -401,10 +401,10 @@ namespace shape {
*/ */
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy); ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy);
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy, T *ret); ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret);
/** /**
* Return a copy of a buffer. * Return a copy of a buffer.
@ -413,13 +413,13 @@ namespace shape {
*/ */
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T *from, T *to); ND4J_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to);
/** /**
* Return a copy of a buffer. * Return a copy of a buffer.
* This buffer allocates memory * This buffer allocates memory
* that must be freed elsewhere. * that must be freed elsewhere.
*/ */
ND4J_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong *from, Nd4jLong *to, Nd4jLong *indexes); ND4J_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes);
/** /**
* Permute the given strides * Permute the given strides
@ -566,7 +566,7 @@ namespace shape {
* item * item
*/ */
template <typename T1, typename T2> template <typename T1, typename T2>
ND4J_EXPORT _CUDA_HD void removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out); ND4J_EXPORT _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out);
/** /**
* Return a copy of this array with the * Return a copy of this array with the
@ -582,7 +582,7 @@ namespace shape {
*/ */
template <typename T1, typename T2> template <typename T1, typename T2>
ND4J_EXPORT _CUDA_HD T1* removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength); ND4J_EXPORT _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength);
/** /**
* Iterate over a given set of indexes * Iterate over a given set of indexes
@ -595,7 +595,7 @@ namespace shape {
* indexes should be the indexes to exclude * indexes should be the indexes to exclude
* indexes length should be the length of indexes * indexes length should be the length of indexes
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong *indexes,int indexesLength,int begin,int end); ND4J_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong const* indexes,int indexesLength,int begin,int end);
/** /**
* Computes the offset for accessing * Computes the offset for accessing
@ -641,7 +641,7 @@ namespace shape {
* Keep the given indexes * Keep the given indexes
* in the data * in the data
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int* index, int indexLength, int dataLength); ND4J_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength);
/** /**
* Generate reverse copy of the data * Generate reverse copy of the data
@ -651,13 +651,13 @@ namespace shape {
*/ */
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD T* reverseCopy(T *data, Nd4jLong length); ND4J_EXPORT _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length);
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong length); ND4J_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length);
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong *indexes, Nd4jLong length); ND4J_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length);
template <typename T1, typename T2> template <typename T1, typename T2>
ND4J_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); ND4J_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length);
@ -670,7 +670,7 @@ namespace shape {
* @return * @return
*/ */
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD T* concat(T* arr1, Nd4jLong arr1Length, T* arr2, Nd4jLong arr2Length); ND4J_EXPORT _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length);
/** /**
* *
@ -681,7 +681,7 @@ namespace shape {
* @return * @return
*/ */
template <typename T> template <typename T>
ND4J_EXPORT _CUDA_HD T* concat(int numArrays, int numTotalElements, Nd4jLong **arr, Nd4jLong *lengths); ND4J_EXPORT _CUDA_HD T* concat(int const numArrays, int const numTotalElements, Nd4jLong const**arr, Nd4jLong const* lengths);
/** /**
* Get the length per slice of the * Get the length per slice of the
@ -695,7 +695,7 @@ namespace shape {
* @return the length per slice of the given shape * @return the length per slice of the given shape
* along the given dimension * along the given dimension
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong *shape, int *dimension, int dimensionLength); ND4J_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength);
/** /**
* calculates the offset for a tensor * calculates the offset for a tensor
@ -706,10 +706,10 @@ namespace shape {
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank,
int index, int index,
Nd4jLong *shape, Nd4jLong const* shape,
Nd4jLong *tensorShape, Nd4jLong const* tensorShape,
int tensorShapeLength, int tensorShapeLength,
int *dimension, int const *dimension,
int dimensionLength); int dimensionLength);
/** /**
@ -1095,7 +1095,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* Length of a tad given * Length of a tad given
* the shape information * the shape information
*/ */
INLINEDEF _CUDA_HD Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { INLINEDEF _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength) {
if(dimensionLength == 1) { if(dimensionLength == 1) {
return shape::shapeOf(shapeInfo)[dimension[0]]; return shape::shapeOf(shapeInfo)[dimension[0]];
} }
@ -1166,7 +1166,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
} }
INLINEDEF _CUDA_HD bool strideEquals(int shape1Rank,Nd4jLong *shape1,int shape2Rank,Nd4jLong *shape2) { INLINEDEF _CUDA_HD bool strideEquals(int const shape1Rank, Nd4jLong const* shape1,int const shape2Rank,Nd4jLong const* shape2) {
if(shape1Rank != shape2Rank) if(shape1Rank != shape2Rank)
return false; return false;
//rank not equals //rank not equals
@ -1178,12 +1178,12 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return true; return true;
} }
INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong *shapeInfo1,Nd4jLong *shapeInfo2) { INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1,Nd4jLong const* shapeInfo2) {
return shape::strideEquals(shape::rank(shapeInfo1),shape::stride(shapeInfo1),shape::rank(shapeInfo2),shape::stride(shapeInfo2)); return shape::strideEquals(shape::rank(shapeInfo1),shape::stride(shapeInfo1),shape::rank(shapeInfo2),shape::stride(shapeInfo2));
} }
INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong *stride1,int rank1 , Nd4jLong *stride2, int rank2) { INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1 , Nd4jLong const* stride2, int const rank2) {
if(rank1 != rank2) if(rank1 != rank2)
return false; return false;
@ -1195,7 +1195,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return true; return true;
} }
INLINEDEF _CUDA_HD Nd4jLong *computeResultShape(Nd4jLong *originalShapeBuffer, int* dimension,int dimensionLength) { INLINEDEF _CUDA_HD Nd4jLong *computeResultShape(Nd4jLong const* originalShapeBuffer, int * dimension,int dimensionLength) {
Nd4jLong *retShape; Nd4jLong *retShape;
int retShapeLength; int retShapeLength;
if(dimensionLength == 1 && dimension[0] == 2147483647) { if(dimensionLength == 1 && dimension[0] == 2147483647) {
@ -1236,7 +1236,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
} }
INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer) { INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer) {
Nd4jLong *theShape = shape::shapeOf(shapeInfo); Nd4jLong *theShape = shape::shapeOf(shapeInfo);
Nd4jLong *theStride = shape::stride(shapeInfo); Nd4jLong *theStride = shape::stride(shapeInfo);
int rank = dimensionLength == 1 ? 2 : dimensionLength; int rank = dimensionLength == 1 ? 2 : dimensionLength;
@ -1279,7 +1279,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
} }
else { else {
Nd4jLong *newIndexes = dimension; Nd4jLong *newIndexes = dimension;
if(reverseCopyStride) if(reverseCopyStride)
shape::reverseCopyTo(theStride, retStride, newIndexes, len); shape::reverseCopyTo(theStride, retStride, newIndexes, len);
else else
@ -1293,7 +1293,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return ret; return ret;
} }
INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride) { INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride) {
int rank = dimensionLength == 1 ? 2 : dimensionLength; int rank = dimensionLength == 1 ? 2 : dimensionLength;
traceNew(4); traceNew(4);
@ -1330,7 +1330,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, int startNum) { INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum) {
if (isVector(shape, rank)) { if (isVector(shape, rank)) {
traceNew(5); traceNew(5);
@ -1356,7 +1356,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return stride; return stride;
} }
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, int startNum, Nd4jLong *ret) { INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum, Nd4jLong *ret) {
if (isVector(shape, rank)) { if (isVector(shape, rank)) {
for (int i = 0; i < rank; i++) for (int i = 0; i < rank; i++)
ret[i] = 1; ret[i] = 1;
@ -1382,7 +1382,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong *shape, int rank, int startNum) { INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const *shape, int rank, int startNum) {
traceNew(7); traceNew(7);
@ -1410,7 +1410,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return stride; return stride;
} }
INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret) { INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const* shape, int rank, int startNum, Nd4jLong* ret) {
if (rank == 1) { if (rank == 1) {
ret[0] = 1; ret[0] = 1;
return ret; return ret;
@ -1439,11 +1439,11 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank) { INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank) {
return calcStridesFortran(shape, rank, 1); return calcStridesFortran(shape, rank, 1);
} }
INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, Nd4jLong* ret) { INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret) {
return calcStridesFortran(shape, rank, 1, ret); return calcStridesFortran(shape, rank, 1, ret);
} }
@ -1454,11 +1454,11 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* @param startNum the start number for the strides * @param startNum the start number for the strides
* @return the strides for a matrix of n dimensions * @return the strides for a matrix of n dimensions
*/ */
INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank) { INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank) {
return calcStrides(shape, rank, 1); return calcStrides(shape, rank, 1);
} }
INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, Nd4jLong* ret) { INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret) {
return calcStrides(shape, rank, 1, ret); return calcStrides(shape, rank, 1, ret);
} }
@ -1541,7 +1541,7 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return copy; return copy;
} }
INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder) { INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder) {
if (rank == 0) if (rank == 0)
return 1; return 1;
@ -1690,8 +1690,8 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
} }
INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder, INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder,
Nd4jLong *dimension, int dimensionLength) { Nd4jLong const* dimension, int dimensionLength) {
if(dimensionLength == 1) { if(dimensionLength == 1) {
return stride[dimension[0]]; return stride[dimension[0]];
} }
@ -1703,13 +1703,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* Get the shape info buffer * Get the shape info buffer
* for the given rank and shape. * for the given rank and shape.
*/ */
INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape) { INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape) {
Nd4jLong *stride = shape::calcStrides(shape, rank); Nd4jLong *stride = shape::calcStrides(shape, rank);
traceNew(11); traceNew(11);
auto shapeInfo = new shape::ShapeInformation(); auto shapeInfo = new shape::ShapeInformation();
shapeInfo->shape = shape; shapeInfo->shape = const_cast<Nd4jLong*>(shape);
shapeInfo->stride = stride; shapeInfo->stride = stride;
shapeInfo->offset = 0; shapeInfo->offset = 0;
shapeInfo->rank = rank; shapeInfo->rank = rank;
@ -1728,13 +1728,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* *
* This method is used only for SoftMax * This method is used only for SoftMax
*/ */
INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *buffer) { INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer) {
Nd4jLong stride[MAX_RANK]; Nd4jLong stride[MAX_RANK];
shape::calcStrides(shape,rank, stride); shape::calcStrides(shape,rank, stride);
shape::ShapeInformation shapeInfo; shape::ShapeInformation shapeInfo;
shapeInfo.shape = shape; shapeInfo.shape = const_cast<Nd4jLong*>(shape);
shapeInfo.stride = stride; shapeInfo.stride = stride;
shapeInfo.offset = 0; shapeInfo.offset = 0;
shapeInfo.rank = rank; shapeInfo.rank = rank;
@ -1751,13 +1751,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
* Get the shape info buffer * Get the shape info buffer
* for the given rank and shape. * for the given rank and shape.
*/ */
INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape) { INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape) {
auto stride = shape::calcStridesFortran(shape,rank); auto stride = shape::calcStridesFortran(shape,rank);
traceNew(12); traceNew(12);
auto shapeInfo = new shape::ShapeInformation(); auto shapeInfo = new shape::ShapeInformation();
shapeInfo->shape = shape; shapeInfo->shape = const_cast<Nd4jLong*>(shape);
shapeInfo->stride = stride; shapeInfo->stride = stride;
shapeInfo->offset = 0; shapeInfo->offset = 0;
shapeInfo->rank = rank; shapeInfo->rank = rank;
@ -1772,13 +1772,13 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
return shapeInfoBuffer; return shapeInfoBuffer;
} }
INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *output) { INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const *shape, Nd4jLong *output) {
Nd4jLong stride[MAX_RANK]; Nd4jLong stride[MAX_RANK];
shape::calcStridesFortran(shape,rank, stride); shape::calcStridesFortran(shape,rank, stride);
shape::ShapeInformation shapeInfo; shape::ShapeInformation shapeInfo;
shapeInfo.shape = shape; shapeInfo.shape = const_cast<Nd4jLong*>(shape);
shapeInfo.stride = stride; shapeInfo.stride = stride;
shapeInfo.offset = 0; shapeInfo.offset = 0;
shapeInfo.rank = rank; shapeInfo.rank = rank;
@ -2049,7 +2049,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
shape::doPermuteShapeInfo(out, rearrange); shape::doPermuteShapeInfo(out, rearrange);
} }
INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong *shapeBuffer, int* rearrange) { INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange) {
auto len = shape::shapeInfoLength(shape::rank(shapeBuffer)); auto len = shape::shapeInfoLength(shape::rank(shapeBuffer));
Nd4jLong *copy = shape::copyOf(len, shapeBuffer); Nd4jLong *copy = shape::copyOf(len, shapeBuffer);
shape::doPermuteShapeInfo(copy,rearrange); shape::doPermuteShapeInfo(copy,rearrange);
@ -2238,7 +2238,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
* @param shape the shape of the array * @param shape the shape of the array
* @param rank the rank of the shape * @param rank the rank of the shape
*/ */
INLINEDEF _CUDA_HD int isVector(Nd4jLong *shape, int rank) { INLINEDEF _CUDA_HD int isVector(Nd4jLong const* shape, int rank) {
if (rank == 0) if (rank == 0)
return 0; return 0;
@ -2254,7 +2254,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return 0; return 0;
} }
INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong *shapeInfo, int& posOfNonUnityDim) { INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim) {
int numOfNonUnity = 0; int numOfNonUnity = 0;
for(int i = 1; i <= shapeInfo[0]; ++i) { for(int i = 1; i <= shapeInfo[0]; ++i) {
@ -2284,7 +2284,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return numOfNonUnity == 1; return numOfNonUnity == 1;
} }
INLINEDEF _CUDA_H Nd4jLong* detachShape(Nd4jLong *originalShape) { INLINEDEF _CUDA_H Nd4jLong const* detachShape(Nd4jLong const* originalShape) {
Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)];
memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape));
@ -2292,7 +2292,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
} }
INLINEDEF _CUDA_H Nd4jLong* copyShape(Nd4jLong *originalShape) { INLINEDEF _CUDA_H Nd4jLong* copyShape(Nd4jLong const* originalShape) {
Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)];
memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape));
@ -2309,7 +2309,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return isVector && shapeFirstOne; return isVector && shapeFirstOne;
} }
INLINEDEF _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo) { INLINEDEF _CUDA_HD bool isColumnVector(const Nd4jLong *shapeInfo) {
bool isVector = shape::isVector(shapeInfo) == 1; bool isVector = shape::isVector(shapeInfo) == 1;
bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1; bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1;
return isVector && !shapeFirstOne; return isVector && !shapeFirstOne;
@ -2381,7 +2381,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* that must be freed elsewhere. * that must be freed elsewhere.
*/ */
template <typename T> template <typename T>
INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T *toCopy) { INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const* toCopy) {
traceNew(18); traceNew(18);
T *ret = new T[length]; T *ret = new T[length];
@ -2389,7 +2389,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
} }
template <typename T> template <typename T>
INLINEDEF _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy, T *ret) { INLINEDEF _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret) {
memcpy(ret, toCopy, sizeof(T)*length); memcpy(ret, toCopy, sizeof(T)*length);
return ret; return ret;
} }
@ -2400,7 +2400,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* that must be freed elsewhere. * that must be freed elsewhere.
*/ */
template <typename T> template <typename T>
INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T *from, T *to) { INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to) {
memcpy(to, from, sizeof(T)*length); memcpy(to, from, sizeof(T)*length);
} }
@ -2409,7 +2409,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* This buffer allocates memory * This buffer allocates memory
* that must be freed elsewhere. * that must be freed elsewhere.
*/ */
INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong *from, Nd4jLong *to, Nd4jLong *indexes) { INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes) {
for(int i = 0; i < length; i++) { for(int i = 0; i < length; i++) {
to[i] = from[indexes[i]]; to[i] = from[indexes[i]];
} }
@ -2817,7 +2817,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* item * item
*/ */
template <typename T1, typename T2> template <typename T1, typename T2>
INLINEDEF _CUDA_HD void removeIndex(T1* data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *ret) { INLINEDEF _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *ret) {
int count = 0; int count = 0;
int absLength = dataLength - indexesLength; int absLength = dataLength - indexesLength;
@ -2850,7 +2850,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
* item * item
*/ */
template <typename T1, typename T2> template <typename T1, typename T2>
INLINEDEF _CUDA_HD T1* removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength) { INLINEDEF _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength) {
auto lengthOfArr = dataLength - indexesLength; auto lengthOfArr = dataLength - indexesLength;
if(lengthOfArr < 0) { if(lengthOfArr < 0) {
printf("Remove index call created a <= 0 length array. This was likely not intended."); printf("Remove index call created a <= 0 length array. This was likely not intended.");
@ -2862,7 +2862,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
return ret; return ret;
} }
INLINEDEF _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong *indexes,int indexesLength,int begin,int end) { INLINEDEF _CUDA_HD Nd4jLong* everyIndexBut(const Nd4jLong *indexes,int indexesLength,int begin,int end) {
int len = end - indexesLength; int len = end - indexesLength;
traceNew(20); traceNew(20);
@ -3086,7 +3086,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @param dataLength * @param dataLength
* @return * @return
*/ */
INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int* index, int indexLength, int dataLength) { INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength) {
traceNew(23); traceNew(23);
@ -3113,7 +3113,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
*/ */
template <typename T> template <typename T>
INLINEDEF _CUDA_HD T* reverseCopy(T *data, Nd4jLong length) { INLINEDEF _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length) {
if (length < 1) if (length < 1)
return nullptr; return nullptr;
@ -3129,7 +3129,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
} }
template <typename T> template <typename T>
INLINEDEF _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong length) { INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length) {
if (length < 1) if (length < 1)
return; return;
for (Nd4jLong i = 0; i <= length / 2; i++) { for (Nd4jLong i = 0; i <= length / 2; i++) {
@ -3140,7 +3140,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
} }
template <typename T> template <typename T>
INLINEDEF _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong *indexes, Nd4jLong length) { INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length) {
if (length < 1) if (length < 1)
return; return;
@ -3161,7 +3161,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @return * @return
*/ */
template <typename T> template <typename T>
INLINEDEF _CUDA_HD T* concat(T* arr1, Nd4jLong arr1Length, T* arr2, Nd4jLong arr2Length) { INLINEDEF _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length) {
traceNew(25); traceNew(25);
@ -3180,7 +3180,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @return * @return
*/ */
template <typename T> template <typename T>
INLINEDEF _CUDA_HD T *concat(Nd4jLong numArrays, Nd4jLong numTotalElements, T **arr, Nd4jLong *lengths) { INLINEDEF _CUDA_HD T *concat(Nd4jLong const numArrays, Nd4jLong const numTotalElements, T const **arr, Nd4jLong const *lengths) {
T* ret = new T[numTotalElements]; T* ret = new T[numTotalElements];
Nd4jLong count = 0; Nd4jLong count = 0;
@ -3206,7 +3206,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @return the length per slice of the given shape * @return the length per slice of the given shape
* along the given dimension * along the given dimension
*/ */
INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong *shape, int* dimension, int dimensionLength) { INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength) {
if(shape::isVector(shape,rank)) { if(shape::isVector(shape,rank)) {
//return total length for row vectors //return total length for row vectors
if(dimensionLength == 1 && shape[0] == 1) { if(dimensionLength == 1 && shape[0] == 1) {
@ -3230,7 +3230,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @param tensorShape * @param tensorShape
* @return * @return
*/ */
INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong *shape, Nd4jLong *tensorShape, int tensorShapeLength, int* dimension, int dimensionLength) { INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong const* shape, Nd4jLong const* tensorShape, int tensorShapeLength, int const* dimension, int dimensionLength) {
auto tensorLength = prodLong(tensorShape, tensorShapeLength); auto tensorLength = prodLong(tensorShape, tensorShapeLength);
auto lengthPerSlice2 = lengthPerSlice(rank, shape, dimension, dimensionLength); auto lengthPerSlice2 = lengthPerSlice(rank, shape, dimension, dimensionLength);
if (lengthPerSlice2 <= 0) { if (lengthPerSlice2 <= 0) {

View File

@ -47,11 +47,11 @@ public:
*/ */
static void execIndexReduceScalar(sd::LaunchContext *lc, static void execIndexReduceScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
/** /**
* *
@ -68,13 +68,13 @@ public:
*/ */
static void execReduce3Scalar(sd::LaunchContext *lc, static void execReduce3Scalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
/** /**
@ -90,13 +90,13 @@ public:
*/ */
static void execReduce3(sd::LaunchContext *lc, static void execReduce3(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
/** /**
* *
@ -113,29 +113,29 @@ public:
*/ */
static void execReduce3(sd::LaunchContext *lc, static void execReduce3(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets, const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets); const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets);
static void execReduce3All(sd::LaunchContext *lc, static void execReduce3All(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets); const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets);
/** /**
* *
@ -150,13 +150,13 @@ public:
*/ */
static void execIndexReduce(sd::LaunchContext *lc, static void execIndexReduce(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
/** /**
* *
@ -170,73 +170,76 @@ public:
* @param n * @param n
*/ */
static void execScalar(sd::LaunchContext *lc, static void execScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo, const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo, const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true); void *extraParams,
bool allowParallelism = true);
static void execScalarBool(sd::LaunchContext *lc, static void execScalarBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo, const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo, const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true); void *extraParams,
bool allowParallelism = true);
static void execScalarInt(sd::LaunchContext *lc, static void execScalarInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo, const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo, const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true); void *extraParams,
bool allowParallelism = true);
static void execScalar(sd::LaunchContext *lc, static void execScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ);
static void execScalarBool(sd::LaunchContext *lc, static void execScalarBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, const void *hScalars, const Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static void execScalarInt(sd::LaunchContext *lc, static void execScalarInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, const void *hScalars, const Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ);
/** /**
@ -252,105 +255,107 @@ static void execScalarInt(sd::LaunchContext *lc,
* @param dimensionLength * @param dimensionLength
*/ */
static void execBroadcast(sd::LaunchContext *lc, static void execBroadcast(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ);
static void execBroadcast(sd::LaunchContext* lc, static void execBroadcast(sd::LaunchContext* lc,
const int opNum, int opNum,
const void *hX, const Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
static void execInverseBroadcast(sd::LaunchContext *lc, static void execInverseBroadcast(sd::LaunchContext *lc,
int opNum, int opNum,
void *x, Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static void execBroadcastBool(sd::LaunchContext *lc, static void execBroadcastBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ);
static void execBroadcastBool(sd::LaunchContext* lc, const int opNum, static void execBroadcastBool(sd::LaunchContext* lc,
const void *hX, const Nd4jLong *hXShapeInfo, int opNum,
const void *dX, const Nd4jLong *dXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
static void execInverseBroadcastBool(sd::LaunchContext *lc, static void execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *x, Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static void execBroadcastInt(sd::LaunchContext *lc, static void execBroadcastInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ);
static void execBroadcastInt(sd::LaunchContext* lc, const int opNum, static void execBroadcastInt(sd::LaunchContext* lc,
const void *hX, const Nd4jLong *hXShapeInfo, int opNum,
const void *dX, const Nd4jLong *dXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo); void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execInverseBroadcastInt(sd::LaunchContext *lc, static void execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *x, Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
/** /**
* *
@ -365,34 +370,34 @@ static void execScalarInt(sd::LaunchContext *lc,
* @param n * @param n
*/ */
static void execPairwiseTransform(sd::LaunchContext *lc, static void execPairwiseTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
static void execPairwiseBoolTransform(sd::LaunchContext *lc, static void execPairwiseBoolTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
static void execPairwiseIntTransform(sd::LaunchContext *lc, static void execPairwiseIntTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
/** /**
* *
@ -405,49 +410,50 @@ static void execScalarInt(sd::LaunchContext *lc,
* @param n * @param n
*/ */
static void execTransformFloat(sd::LaunchContext *lc, static void execTransformFloat(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execTransformAny(sd::LaunchContext *lc, static void execTransformAny(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism = true); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool allowParallelism = true);
static void execTransformStrict(sd::LaunchContext *lc, static void execTransformStrict(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execTransformSame(sd::LaunchContext *lc, static void execTransformSame(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execTransformBool(sd::LaunchContext *lc, static void execTransformBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
/** /**
* *
* @param opNum * @param opNum
@ -458,44 +464,44 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param resultShapeInfo * @param resultShapeInfo
*/ */
static void execReduceFloat(sd::LaunchContext *lc, static void execReduceFloat(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execReduceSame(sd::LaunchContext *lc, static void execReduceSame(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execReduceBool(sd::LaunchContext *lc, static void execReduceBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
static void execReduceLong(sd::LaunchContext *lc, static void execReduceLong(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
/** /**
* *
@ -506,49 +512,49 @@ static void execTransformBool(sd::LaunchContext *lc,
* @return * @return
*/ */
static void execReduceFloatScalar(sd::LaunchContext *lc, static void execReduceFloatScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduceBoolScalar(sd::LaunchContext *lc, static void execReduceBoolScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduceSameScalar(sd::LaunchContext *lc, static void execReduceSameScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduceLongScalar(sd::LaunchContext *lc, static void execReduceLongScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo); void *dZ, const Nd4jLong *dZShapeInfo);
static void execReduce3TAD(sd::LaunchContext *lc, static void execReduce3TAD(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets); const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets);
/** /**
* *
@ -562,15 +568,15 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param dimensionLength * @param dimensionLength
*/ */
static void execSummaryStats(sd::LaunchContext *lc, static void execSummaryStats(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool biasCorrected); bool biasCorrected);
/** /**
* *
@ -582,13 +588,13 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param resultShapeInfo * @param resultShapeInfo
*/ */
static void execSummaryStats(sd::LaunchContext *lc, static void execSummaryStats(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected); bool biasCorrected);
/** /**
* *
@ -600,68 +606,51 @@ static void execTransformBool(sd::LaunchContext *lc,
* @param resultShapeInfo * @param resultShapeInfo
*/ */
static void execSummaryStatsScalar(sd::LaunchContext *lc, static void execSummaryStatsScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected); bool biasCorrected);
static void execRandom(sd::LaunchContext *lc, static void execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hZ, Nd4jLong *hZShapeBuffer, void *hZ, const Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer, void *dZ, const Nd4jLong *dZShapeBuffer,
void *extraArguments); void *extraArguments);
static void execRandom(sd::LaunchContext *lc, static void execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer, const void *hX, const Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer, const void *dX, const Nd4jLong *dXShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer, void *hZ, const Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer, void *dZ, const Nd4jLong *dZShapeBuffer,
void *extraArguments); void *extraArguments);
static void execRandom(sd::LaunchContext *lc, static void execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer, const void *hX, const Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer, const void *dX, const Nd4jLong *dXShapeBuffer,
void *hY, Nd4jLong *hYShapeBuffer, const void *hY, const Nd4jLong *hYShapeBuffer,
void *dY, Nd4jLong *dYShapeBuffer, const void *dY, const Nd4jLong *dYShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer, void *hZ, const Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer, void *dZ, const Nd4jLong *dZShapeBuffer,
void *extraArguments); void *extraArguments);
template <typename X> inline static void execSort(void *x, const Nd4jLong *xShapeInfo, bool descending) {
static FORCEINLINE void execAggregate(sd::LaunchContext *lc,
int opNum,
void **varguments,
int numArguments,
Nd4jLong **shapeArguments,
int numShapeArguments,
int *indexArguments,
int numIndexArguments,
int **intArrays,
int numIntArrays,
void *vrealArguments,
int numRealArguments) {
}
inline static void execSort(void *x, Nd4jLong *xShapeInfo, bool descending) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo); auto xType = sd::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortGeneric(x, xShapeInfo, descending), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortGeneric(x, xShapeInfo, descending), LIBND4J_TYPES);
} }
static void execSort(void *x, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending) { static void execSort(void *x, const Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, bool descending) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo); auto xType = sd::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
@ -672,13 +661,13 @@ static void execTransformBool(sd::LaunchContext *lc,
} }
inline static Nd4jLong encodeBitmap(void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) { inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo); auto xType = sd::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, return sd::SpecialMethods, ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, return sd::SpecialMethods, ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES);
} }
inline static void decodeBitmap(void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo) { inline static void decodeBitmap(const void *dx, Nd4jLong N, void *dz, const Nd4jLong *zShapeInfo) {
auto zType = sd::ArrayOptions::dataType(zShapeInfo); auto zType = sd::ArrayOptions::dataType(zShapeInfo);
BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), FLOAT_TYPES);

View File

@ -122,9 +122,9 @@ ND4J_EXPORT void setTADThreshold(int num);
*/ */
ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/** /**
* *
@ -139,10 +139,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
/** /**
* *
@ -159,20 +159,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
ND4J_EXPORT void execBroadcast( ND4J_EXPORT void execBroadcast(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execBroadcastBool( ND4J_EXPORT void execBroadcastBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
/** /**
* *
@ -189,17 +189,17 @@ ND4J_EXPORT void execBroadcastBool(
ND4J_EXPORT void execPairwiseTransform( ND4J_EXPORT void execPairwiseTransform(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execPairwiseTransformBool( ND4J_EXPORT void execPairwiseTransformBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
/** /**
@ -213,28 +213,28 @@ ND4J_EXPORT void execPairwiseTransformBool(
*/ */
ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/** /**
* *
@ -247,34 +247,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape);
/** /**
* *
@ -289,10 +289,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/** /**
* *
@ -305,10 +305,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo);
/** /**
* *
* @param opNum * @param opNum
@ -324,24 +324,24 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets); Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets);
ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets); Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets);
/** /**
* *
@ -356,16 +356,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo,
void *extraParams); void *extraParams);
/** /**
@ -377,9 +377,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected); bool biasCorrected);
/** /**
* *
@ -392,9 +392,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected); bool biasCorrected);
/** /**
* *
@ -409,12 +409,12 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
bool biasCorrected, bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets);
/** /**
* *
@ -428,32 +428,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams); void *extraParams);
/** /**
@ -471,23 +471,23 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ);
ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ);
ND4J_EXPORT void specialConcat ( ND4J_EXPORT void specialConcat (
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
@ -496,7 +496,7 @@ ND4J_EXPORT void specialConcat (
Nd4jPointer *data, Nd4jPointer *data,
Nd4jPointer *inputShapeInfo, Nd4jPointer *inputShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong const* resultShapeInfo,
Nd4jPointer *tadPointers, Nd4jPointer *tadPointers,
Nd4jPointer *offsetPointers); Nd4jPointer *offsetPointers);
@ -792,14 +792,14 @@ typedef sd::TadPack OpaqueTadPack;
* @param targetBuffer * @param targetBuffer
* @param offsetsBuffer * @param offsetsBuffer
*/ */
ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo, ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong const*xShapeInfo,
int *dimension, int *dimension,
int dimensionLength); int dimensionLength);
ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack); ND4J_EXPORT Nd4jLong const* getPrimaryShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack); ND4J_EXPORT Nd4jLong const* getPrimaryOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack); ND4J_EXPORT Nd4jLong const* getSpecialShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack); ND4J_EXPORT Nd4jLong const* getSpecialOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack); ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack);
ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack); ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack);
@ -824,14 +824,14 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
* @param zTadOffsets * @param zTadOffsets
*/ */
ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dzShapeInfo,
Nd4jLong n, Nd4jLong n,
Nd4jLong *indexes, Nd4jLong *indexes,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong const* tadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong const* zTadShapeInfo,
Nd4jLong *zTadOffsets); Nd4jLong const* zTadOffsets);
/** /**
* *
@ -843,20 +843,20 @@ ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
* @param propagate * @param propagate
*/ */
ND4J_EXPORT void average(Nd4jPointer *extras, ND4J_EXPORT void average(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo, Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dxShapeInfo, Nd4jPointer *dx, Nd4jLong const* dxShapeInfo,
void *z, Nd4jLong *zShapeInfo, void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo, void *dz, Nd4jLong const* dzShapeInfo,
int n, int n,
Nd4jLong length, Nd4jLong length,
bool propagate); bool propagate);
ND4J_EXPORT void accumulate(Nd4jPointer *extras, ND4J_EXPORT void accumulate(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo, Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dxShapeInfo, Nd4jPointer *dx, Nd4jLong const* dxShapeInfo,
void *z, Nd4jLong *zShapeInfo, void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo, void *dz, Nd4jLong const* dzShapeInfo,
int n, int n,
Nd4jLong length); Nd4jLong length);
@ -1004,7 +1004,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer,
void *extraArguments); void *extraArguments);
/** /**
@ -1023,9 +1023,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeBuffer, Nd4jLong const* dYShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer,
void *extraArguments); void *extraArguments);
/** /**
@ -1042,8 +1042,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer,
void *extraArguments); void *extraArguments);
@ -1098,11 +1098,11 @@ ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom);
*/ */
template <typename T> template <typename T>
static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) { static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,const Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) {
Nd4jLong *shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer); Nd4jLong const* shapeBufferCast = reinterpret_cast<const Nd4jLong *>(shapeBuffer);
int rank = shape::rank(shapeBufferCast); int rank = shape::rank(shapeBufferCast);
Nd4jLong *shape = shape::shapeOf(shapeBufferCast); const Nd4jLong* shape = shape::shapeOf(shapeBufferCast);
unsigned int *npShape = new unsigned int[rank]; unsigned int* npShape = new unsigned int[rank];
for(int i = 0; i < rank; i++) { for(int i = 0; i < rank; i++) {
npShape[i] = shape[i]; npShape[i] = shape[i];
} }
@ -1125,7 +1125,7 @@ static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,
extern "C" { extern "C" {
static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) { static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) {
auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer); auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
auto type = sd::ArrayOptions::dataType(shapeBufferCast); auto type = sd::ArrayOptions::dataType(shapeBufferCast);
BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES);
@ -1427,53 +1427,53 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address);
* @return * @return
*/ */
ND4J_EXPORT void tear(Nd4jPointer *extraPointers, ND4J_EXPORT void tear(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo,
Nd4jPointer *targets, Nd4jLong *zShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets); Nd4jLong const* tadOffsets);
ND4J_EXPORT void sort(Nd4jPointer *extraPointers, ND4J_EXPORT void sort(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong const* dxShapeInfo,
bool descending); bool descending);
ND4J_EXPORT void sortByKey(Nd4jPointer *extraPointers, ND4J_EXPORT void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
bool descending); bool descending);
ND4J_EXPORT void sortByValue(Nd4jPointer *extraPointers, ND4J_EXPORT void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
bool descending); bool descending);
ND4J_EXPORT void sortTad(Nd4jPointer *extraPointers, ND4J_EXPORT void sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong const* dxShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong const* tadOffsets,
bool descending); bool descending);
ND4J_EXPORT void sortTadByKey(Nd4jPointer *extraPointers, ND4J_EXPORT void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
bool descending); bool descending);
ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong const* dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
bool descending); bool descending);
@ -1509,7 +1509,7 @@ ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, N
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); ND4J_EXPORT Nd4jLong const* getShape(OpaqueShapeList* list, Nd4jLong i);
ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList); ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
@ -1526,7 +1526,7 @@ ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i);
ND4J_EXPORT int getVariableId(OpaqueVariable* variable); ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable); ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable); ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable); ND4J_EXPORT Nd4jLong const* getVariableShape(OpaqueVariable* variable);
ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable); ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable);
ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId);
@ -1545,7 +1545,7 @@ ND4J_EXPORT void deleteGraphState(Nd4jPointer state);
ND4J_EXPORT void deleteResultWrapper(Nd4jPointer ptr); ND4J_EXPORT void deleteResultWrapper(Nd4jPointer ptr);
ND4J_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong *xShapeInfo, int N, float threshold); ND4J_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong const* xShapeInfo, int N, float threshold);
// this method executes op that requires scope to be present: if/while/cond/whatever // this method executes op that requires scope to be present: if/while/cond/whatever
ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs); ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs);
@ -1557,11 +1557,11 @@ ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer pt
ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets,
void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo); void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo);
ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
@ -1570,7 +1570,7 @@ typedef sd::ConstantDataBuffer OpaqueConstantDataBuffer;
ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty); ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length); ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length); ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor); ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor);

View File

@ -77,11 +77,11 @@
* @param hZShapeInfo * @param hZShapeInfo
*/ */
void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNum, void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
@ -106,22 +106,21 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNu
*/ */
void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
Nd4jLong* hz = reinterpret_cast<Nd4jLong*>(hZ); auto hz = reinterpret_cast<Nd4jLong*>(hZ);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES);
// BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -139,16 +138,16 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
*/ */
void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
@ -230,15 +229,15 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext* lc, const int opNum,
void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo); auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -269,17 +268,17 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
@ -320,17 +319,17 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opN
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo); auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -358,16 +357,16 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -422,16 +421,16 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, const int opN
} }
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo); auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -471,14 +470,14 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
* @param n * @param n
*/ */
void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc, void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo); auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -504,14 +503,14 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc, void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -538,14 +537,14 @@ void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc, void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo); auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
@ -580,14 +579,14 @@ void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
* @param hZShapeInfo * @param hZShapeInfo
*/ */
void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
@ -609,14 +608,14 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -637,14 +636,14 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -665,14 +664,14 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -701,12 +700,12 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
* @return * @return
*/ */
void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -717,12 +716,12 @@ void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -732,14 +731,12 @@ void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -749,13 +746,12 @@ void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -779,14 +775,15 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
* @param dimensionLength * @param dimensionLength
*/ */
void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -807,15 +804,14 @@ void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
* @param hZShapeInfo * @param hZShapeInfo
*/ */
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -826,17 +822,17 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets, const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -867,18 +863,17 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -895,19 +890,17 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) { const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -948,14 +941,15 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
* @param n * @param n
*/ */
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo, const void *hScalar, const Nd4jLong *hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo, const void *dScalar, const Nd4jLong *dScalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -983,16 +977,16 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const*hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const*dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const*hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const*dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, void const* hScalars, Nd4jLong const*hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, void const* dScalars, Nd4jLong const*dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const*tadShapeInfo, Nd4jLong const*tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const*tadShapeInfoZ, Nd4jLong const*tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -1019,14 +1013,15 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo, const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo, const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo);
@ -1052,17 +1047,17 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, const void *hScalars, const Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -1087,14 +1082,15 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo, const void *hScalar, const Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo, const void *dScalar, const Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo);
@ -1120,17 +1116,17 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, const void *hScalars, const Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, const void *dScalars, const Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo);
@ -1164,13 +1160,13 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
* @param hZShapeInfo * @param hZShapeInfo
*/ */
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1190,13 +1186,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
* @param hZShapeInfo * @param hZShapeInfo
*/ */
void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1218,15 +1214,15 @@ void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
* @param dimensionLength * @param dimensionLength
*/ */
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool biasCorrected) { bool biasCorrected) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1246,13 +1242,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
* @param n * @param n
*/ */
void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1268,13 +1264,13 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1290,13 +1286,14 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
bool allowParallelism) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1319,13 +1316,13 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1341,13 +1338,13 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1363,11 +1360,11 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1380,14 +1377,13 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
auto zType = sd::ArrayOptions::dataType(hZShapeInfo); auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@ -1399,16 +1395,15 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeInfo, const void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, const void *hY, const Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
auto xType = sd::ArrayOptions::dataType(hZShapeInfo); auto xType = sd::ArrayOptions::dataType(hZShapeInfo);

View File

@ -102,9 +102,9 @@ void setTADThreshold(int num) {
*/ */
void execIndexReduceScalar(Nd4jPointer *extraPointers, void execIndexReduceScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
} catch (std::exception &e) { } catch (std::exception &e) {
@ -125,10 +125,10 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers,
* @param dimensionLength * @param dimensionLength
*/ */
void execIndexReduce(Nd4jPointer *extraPointers,int opNum, void execIndexReduce(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -176,18 +176,16 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum,
*/ */
void execBroadcast(Nd4jPointer *extraPointers, void execBroadcast(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
dimensionLength); auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength);
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension,
dimensionLength);
auto hTADShapeInfo = tadPackX.primaryShapeInfo(); auto hTADShapeInfo = tadPackX.primaryShapeInfo();
auto hTADOffsets = tadPackX.primaryOffsets(); auto hTADOffsets = tadPackX.primaryOffsets();
@ -216,19 +214,17 @@ void execBroadcast(Nd4jPointer *extraPointers,
void execBroadcastBool(Nd4jPointer *extraPointers, void execBroadcastBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
dimensionLength); auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength);
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension,
dimensionLength);
auto hTADShapeInfo = tadPackX.primaryShapeInfo(); auto hTADShapeInfo = tadPackX.primaryShapeInfo();
auto hTADOffsets = tadPackX.primaryOffsets(); auto hTADOffsets = tadPackX.primaryOffsets();
@ -272,9 +268,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void execPairwiseTransform( void execPairwiseTransform(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execPairwiseTransform(nullptr, NativeOpExecutioner::execPairwiseTransform(nullptr,
@ -301,9 +297,9 @@ void execPairwiseTransform(
void execPairwiseTransformBool( void execPairwiseTransformBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
@ -340,9 +336,9 @@ void execPairwiseTransformBool(
void execReduceFloat( void execReduceFloat(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execReduceFloatScalar(nullptr, NativeOpExecutioner::execReduceFloatScalar(nullptr,
@ -365,9 +361,9 @@ void execReduceFloat(
void execReduceSame( void execReduceSame(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execReduceSameScalar(nullptr, NativeOpExecutioner::execReduceSameScalar(nullptr,
@ -390,9 +386,9 @@ void execReduceSame(
void execReduceBool( void execReduceBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execReduceBoolScalar(nullptr, NativeOpExecutioner::execReduceBoolScalar(nullptr,
opNum, opNum,
@ -414,9 +410,9 @@ void execReduceBool(
void execReduceLong( void execReduceLong(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execReduceLongScalar(nullptr, NativeOpExecutioner::execReduceLongScalar(nullptr,
opNum, opNum,
@ -446,16 +442,15 @@ void execReduceLong(
*/ */
void execReduceFloat2(Nd4jPointer *extraPointers, void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
dimensionLength);
auto hTADShapeInfo = tadPackX.primaryShapeInfo(); auto hTADShapeInfo = tadPackX.primaryShapeInfo();
auto hTADOffsets = tadPackX.primaryOffsets(); auto hTADOffsets = tadPackX.primaryOffsets();
@ -482,13 +477,13 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
void execReduceBool2(Nd4jPointer *extraPointers, void execReduceBool2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
dimensionLength); dimensionLength);
@ -518,10 +513,10 @@ void execReduceBool2(Nd4jPointer *extraPointers,
void execReduceSame2(Nd4jPointer *extraPointers, void execReduceSame2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -554,16 +549,15 @@ void execReduceSame2(Nd4jPointer *extraPointers,
void execReduceLong2(Nd4jPointer *extraPointers, void execReduceLong2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
dimensionLength);
auto hTADShapeInfo = tadPack.primaryShapeInfo(); auto hTADShapeInfo = tadPack.primaryShapeInfo();
auto hTADOffsets = tadPack.primaryOffsets(); auto hTADOffsets = tadPack.primaryOffsets();
@ -601,10 +595,10 @@ void execReduceLong2(Nd4jPointer *extraPointers,
*/ */
void execReduce3(Nd4jPointer *extraPointers, void execReduce3(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo,
dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
@ -624,10 +618,10 @@ void execReduce3(Nd4jPointer *extraPointers,
* @param hYShapeInfo * @param hYShapeInfo
*/ */
void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
try { try {
NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(),
hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
@ -651,16 +645,16 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
*/ */
void execReduce3Tad(Nd4jPointer *extraPointers, void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
if (extraPointers == nullptr || extraPointers[2] == 0) { if (extraPointers == nullptr || extraPointers[2] == 0) {
NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo,
@ -704,9 +698,9 @@ bool isBlasVersionMatches(int major, int minor, int build) {
void execScalar( void execScalar(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execScalar(nullptr, NativeOpExecutioner::execScalar(nullptr,
@ -733,9 +727,9 @@ void execScalar(
void execScalarBool( void execScalarBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execScalarBool(nullptr, NativeOpExecutioner::execScalarBool(nullptr,
@ -768,9 +762,9 @@ void execScalarBool(
*/ */
void execSummaryStatsScalar(Nd4jPointer *extraPointers, void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
try { try {
NativeOpExecutioner::execSummaryStatsScalar(nullptr, NativeOpExecutioner::execSummaryStatsScalar(nullptr,
@ -801,9 +795,9 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers,
*/ */
void execSummaryStats(Nd4jPointer *extraPointers, void execSummaryStats(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
try { try {
NativeOpExecutioner::execSummaryStats(nullptr, NativeOpExecutioner::execSummaryStats(nullptr,
@ -836,12 +830,12 @@ void execSummaryStats(Nd4jPointer *extraPointers,
*/ */
void execSummaryStatsTad(Nd4jPointer *extraPointers, void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
bool biasCorrected, bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -882,8 +876,8 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers,
void execTransformFloat( void execTransformFloat(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execTransformFloat(nullptr, NativeOpExecutioner::execTransformFloat(nullptr,
@ -908,8 +902,8 @@ void execTransformFloat(
void execTransformSame( void execTransformSame(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execTransformSame(nullptr, NativeOpExecutioner::execTransformSame(nullptr,
@ -934,8 +928,8 @@ void execTransformSame(
void execTransformBool( void execTransformBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execTransformBool(nullptr, NativeOpExecutioner::execTransformBool(nullptr,
@ -960,8 +954,8 @@ void execTransformBool(
void execTransformAny( void execTransformAny(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execTransformAny(nullptr, NativeOpExecutioner::execTransformAny(nullptr,
@ -986,8 +980,8 @@ void execTransformAny(
void execTransformStrict( void execTransformStrict(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
NativeOpExecutioner::execTransformStrict(nullptr, NativeOpExecutioner::execTransformStrict(nullptr,
@ -1011,19 +1005,17 @@ void execTransformStrict(
void execReduce3All(Nd4jPointer *extraPointers, void execReduce3All(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) {
Nd4jLong *yTadShapeInfo,
Nd4jLong *yOffsets) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(), NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(),
@ -1046,7 +1038,7 @@ void specialConcat(
Nd4jPointer *data, Nd4jPointer *data,
Nd4jPointer *inputShapeInfo, Nd4jPointer *inputShapeInfo,
void *hZ, void *hZ,
Nd4jLong *hZShapeInfo, Nd4jLong const* hZShapeInfo,
Nd4jPointer *tadPointers, Nd4jPointer *tadPointers,
Nd4jPointer *offsetPointers) { Nd4jPointer *offsetPointers) {
try { try {
@ -1227,7 +1219,7 @@ void setGridLimit(int gridSize) {
// no-op // no-op
} }
sd::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimensionLength) { sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* hXShapeInfo, int *dimension, int dimensionLength) {
auto pack = new TadPack(); auto pack = new TadPack();
try { try {
*pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); *pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
@ -1239,21 +1231,26 @@ sd::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimensi
return pack; return pack;
} }
Nd4jLong* getPrimaryShapeInfo(sd::TadPack* pack) { Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) {
return pack->primaryShapeInfo(); return const_cast<Nd4jLong*>(pack->primaryShapeInfo());
} }
Nd4jLong* getPrimaryOffsets(sd::TadPack* pack) {
return pack->primaryOffsets(); Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->primaryOffsets());
} }
Nd4jLong* getSpecialShapeInfo(sd::TadPack* pack) {
return pack->specialShapeInfo(); Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->specialShapeInfo());
} }
Nd4jLong* getSpecialOffsets(sd::TadPack* pack) {
return pack->specialOffsets(); Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) {
return const_cast<Nd4jLong*>(pack->specialOffsets());
} }
Nd4jLong getNumberOfTads(sd::TadPack* pack) { Nd4jLong getNumberOfTads(sd::TadPack* pack) {
return pack->numberOfTads(); return pack->numberOfTads();
} }
int getShapeInfoLength(sd::TadPack* pack) { int getShapeInfoLength(sd::TadPack* pack) {
return pack->shapeInfoLength(); return pack->shapeInfoLength();
} }
@ -1270,15 +1267,15 @@ Nd4jPointer getConstantSpace() {
template<typename T> template<typename T>
void pullRowsGeneric(void *vx, void pullRowsGeneric(void *vx,
Nd4jLong *hXShapeInfo, Nd4jLong const* hXShapeInfo,
void *vz, void *vz,
Nd4jLong *hZShapeInfo, Nd4jLong const* hZShapeInfo,
const int n, const int n,
Nd4jLong *indexes, Nd4jLong const* indexes,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong const* tadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong const* zTadShapeInfo,
Nd4jLong *zTadOffsets) { Nd4jLong const* zTadOffsets) {
auto hX = reinterpret_cast<T *>(vx); auto hX = reinterpret_cast<T *>(vx);
auto hZ = reinterpret_cast<T *>(vz); auto hZ = reinterpret_cast<T *>(vz);
@ -1322,14 +1319,14 @@ void pullRowsGeneric(void *vx,
} }
void pullRows(Nd4jPointer *extraPointers, void pullRows(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
Nd4jLong n, Nd4jLong n,
Nd4jLong *indexes, Nd4jLong* indexes,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong const* tadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong const* zTadShapeInfo,
Nd4jLong *zTadOffsets) { Nd4jLong const* zTadOffsets) {
try { try {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1342,11 +1339,11 @@ void pullRows(Nd4jPointer *extraPointers,
template<typename T> template<typename T>
void tearGeneric(void *vx, void tearGeneric(void *vx,
Nd4jLong *hXShapeInfo, Nd4jLong const* hXShapeInfo,
Nd4jPointer *targets, Nd4jPointer *targets,
Nd4jLong *hZShapeInfo, Nd4jLong const* hZShapeInfo,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets) { Nd4jLong const* tadOffsets) {
auto hX = reinterpret_cast<T *>(vx); auto hX = reinterpret_cast<T *>(vx);
@ -1381,11 +1378,11 @@ void tearGeneric(void *vx,
} }
void tear(Nd4jPointer *extraPointers, void tear(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
Nd4jPointer *targets, Nd4jPointer *targets,
Nd4jLong *hZShapeInfo, Nd4jLong const* hZShapeInfo,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets) { Nd4jLong const* tadOffsets) {
try { try {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -1398,10 +1395,10 @@ void tear(Nd4jPointer *extraPointers,
void average(Nd4jPointer *extras, void average(Nd4jPointer *extras,
Nd4jPointer *hX, Nd4jLong *hXShapeInfo, Nd4jPointer *hX, const Nd4jLong *hXShapeInfo,
Nd4jPointer *dX, Nd4jLong *dXShapeInfo, Nd4jPointer *dX, const Nd4jLong *dXShapeInfo,
void *z, Nd4jLong *hZShapeInfo, void *z, const Nd4jLong *hZShapeInfo,
void *dz, Nd4jLong *dZShapeInfo, void *dz, const Nd4jLong *dZShapeInfo,
int n, int n,
Nd4jLong length, Nd4jLong length,
bool propagate) { bool propagate) {
@ -1416,10 +1413,10 @@ void average(Nd4jPointer *extras,
} }
void accumulate(Nd4jPointer *extras, void accumulate(Nd4jPointer *extras,
Nd4jPointer *hX, Nd4jLong *hXShapeInfo, Nd4jPointer *hX, Nd4jLong const* hXShapeInfo,
Nd4jPointer *dX, Nd4jLong *dXShapeInfo, Nd4jPointer *dX, Nd4jLong const* dXShapeInfo,
void *hz, Nd4jLong *hZShapeInfo, void *hz, Nd4jLong const* hZShapeInfo,
void *dz, Nd4jLong *dZShapeInfo, void *dz, Nd4jLong const* dZShapeInfo,
int n, int n,
Nd4jLong length) { Nd4jLong length) {
try { try {
@ -1436,6 +1433,28 @@ void enableP2P(bool enable) {
// no-op // no-op
} }
void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
// TODO: to be implemented
}
void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, int *dz) {
// TODO: to be implemented
}
void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, int *offsets, Nd4jLong N, int *dz){
// offsets won't be used here
// TODO: to be implemented
}
void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, const Nd4jLong *hZShapeInfo){
// TODO: to be implemented
}
bool isP2PAvailable() { bool isP2PAvailable() {
// always TRUE for cpu backend // always TRUE for cpu backend
return true; return true;
@ -1445,8 +1464,12 @@ void checkP2P() {
// no-op // no-op
} }
void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, Nd4jLong const* hZShapeInfo) {
NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo);
}
template<typename T> template<typename T>
void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZShapeInfo, int N, int *shuffleMap, Nd4jLong **tadOnlyShapeInfo, Nd4jLong **tadOffsets) { void shuffleGeneric(void **hX, Nd4jLong * const*hXShapeInfo, void **dz, Nd4jLong * const* hZShapeInfo, int N, int *shuffleMap, Nd4jLong * const* tadOnlyShapeInfo, Nd4jLong * const* tadOffsets) {
auto dX = reinterpret_cast<T **>(hX); auto dX = reinterpret_cast<T **>(hX);
auto dZ = reinterpret_cast<T **>(dz); auto dZ = reinterpret_cast<T **>(dz);
@ -1517,10 +1540,10 @@ void shuffle(Nd4jPointer *extras,
Nd4jPointer *tadShapeInfo, Nd4jPointer *tadShapeInfo,
Nd4jPointer *tadOffsets) { Nd4jPointer *tadOffsets) {
try { try {
auto xShape = reinterpret_cast<Nd4jLong **>(hXShapeInfo); auto xShape = reinterpret_cast<Nd4jLong * const*>(hXShapeInfo);
auto zShape = reinterpret_cast<Nd4jLong **>(hZShapeInfo); auto zShape = reinterpret_cast<Nd4jLong * const*>(hZShapeInfo);
auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong **>(tadShapeInfo); auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong * const*>(tadShapeInfo);
auto tadOffset = reinterpret_cast<Nd4jLong **>(tadOffsets); auto tadOffset = reinterpret_cast<Nd4jLong * const*>(tadOffsets);
auto xType = sd::ArrayOptions::dataType(xShape[0]); auto xType = sd::ArrayOptions::dataType(xShape[0]);
@ -1548,13 +1571,13 @@ int getDevice() {
void execScalarTad(Nd4jPointer *extraPointers, void execScalarTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const*dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const*tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const*tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -1588,13 +1611,13 @@ void execScalarTad(Nd4jPointer *extraPointers,
void execScalarBoolTad(Nd4jPointer *extraPointers, void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalars, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
try { try {
auto dimension = reinterpret_cast<int *>(dbDimension->primary()); auto dimension = reinterpret_cast<int *>(dbDimension->primary());
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -1696,7 +1719,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
void execRandom(Nd4jPointer *extraPointers, void execRandom(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
try { try {
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
@ -1709,9 +1732,9 @@ void execRandom(Nd4jPointer *extraPointers,
void execRandom3(Nd4jPointer *extraPointers, void execRandom3(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
try { try {
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
@ -1724,8 +1747,8 @@ void execRandom3(Nd4jPointer *extraPointers,
void execRandom2(Nd4jPointer *extraPointers, void execRandom2(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
try { try {
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
@ -1793,8 +1816,8 @@ Nd4jPointer pointerForAddress(Nd4jLong address) {
} }
void sort(Nd4jPointer *extraPointers, void sort(Nd4jPointer *extraPointers,
void *hX, Nd4jLong *hXShapeInfo, void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, const Nd4jLong *dXShapeInfo,
bool descending) { bool descending) {
try { try {
NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); NativeOpExecutioner::execSort(hX, hXShapeInfo, descending);
@ -1805,12 +1828,11 @@ void sort(Nd4jPointer *extraPointers,
} }
void sortTad(Nd4jPointer *extraPointers, void sortTad(Nd4jPointer *extraPointers,
void *hX, Nd4jLong *hXShapeInfo, void *hX, const Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, const Nd4jLong *dXShapeInfo,
int *dimension, int *dimension, int dimensionLength,
int dimensionLength, const Nd4jLong *tadShapeInfo,
Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *tadOffsets,
bool descending) { bool descending) {
try { try {
NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
@ -1833,6 +1855,12 @@ void sortCooIndices(Nd4jPointer *extraPointers,
} }
} }
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
}
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
auto hZ = new Nd4jLong[2];errno = 0; auto hZ = new Nd4jLong[2];errno = 0;
try { try {
@ -1916,7 +1944,7 @@ FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer
} }
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXShapeInfo, int N, float threshold) { int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong const* hXShapeInfo, int N, float threshold) {
try { try {
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
@ -1931,8 +1959,8 @@ Nd4jLong getShapeListSize(sd::ShapeList* list) {
return list->size(); return list->size();
} }
Nd4jLong* getShape(sd::ShapeList* list, Nd4jLong i) { Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) {
return list->at(i); return const_cast<Nd4jLong const*>(list->at(i));
} }
void deleteShapeList(Nd4jPointer shapeList) { void deleteShapeList(Nd4jPointer shapeList) {
@ -2226,8 +2254,8 @@ const char* getVariableName(sd::graph::Variable* variable) {
return variable->getName()->c_str(); return variable->getName()->c_str();
} }
Nd4jLong* getVariableShape(sd::graph::Variable* variable) { Nd4jLong const* getVariableShape(sd::graph::Variable* variable) {
return variable->getNDArray()->shapeInfo(); return const_cast<Nd4jLong const*>(variable->getNDArray()->shapeInfo());
} }
void* getVariableBuffer(sd::graph::Variable* variable) { void* getVariableBuffer(sd::graph::Variable* variable) {
@ -2569,12 +2597,13 @@ void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
} }
template <typename I> template <typename I>
static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, static void _scatterUpdate(
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets,
void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets,
void* vIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) {
auto hIindexes = reinterpret_cast<I*>(vIindexes); auto hIindexes = reinterpret_cast<I*>(vIindexes);
auto func = PRAGMA_THREADS_DO { auto func = PRAGMA_THREADS_DO {
@ -2626,11 +2655,11 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets,
void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { void* hIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) {
auto iType = ArrayOptions::dataType(hIndicesShapeInfo); auto iType = ArrayOptions::dataType(hIndicesShapeInfo);
try { try {
@ -2686,7 +2715,7 @@ void deleteTadPack(sd::TadPack* ptr) {
delete ptr; delete ptr;
} }
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length) { sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, const Nd4jLong *data, int length) {
return nullptr; return nullptr;
} }
@ -2847,7 +2876,7 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
} else { } else {
shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
} }
return reinterpret_cast<Nd4jPointer>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); return const_cast<Nd4jLong*>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
} catch (std::exception &e) { } catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@ -2856,10 +2885,10 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
} }
void sortByKey(Nd4jPointer *extraPointers, void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, const Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, const Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, const Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, const Nd4jLong *dyShapeInfo,
bool descending) { bool descending) {
try { try {
auto xType = ArrayOptions::dataType(xShapeInfo); auto xType = ArrayOptions::dataType(xShapeInfo);
@ -2873,10 +2902,10 @@ void sortByKey(Nd4jPointer *extraPointers,
} }
void sortByValue(Nd4jPointer *extraPointers, void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, const Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, const Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, const Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, const Nd4jLong *dyShapeInfo,
bool descending) { bool descending) {
try { try {
auto xType = ArrayOptions::dataType(xShapeInfo); auto xType = ArrayOptions::dataType(xShapeInfo);
@ -2890,12 +2919,11 @@ void sortByValue(Nd4jPointer *extraPointers,
} }
void sortTadByKey(Nd4jPointer *extraPointers, void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, const Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, const Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, const Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, const Nd4jLong *dyShapeInfo,
int *dimension, int *dimension, int dimensionLength,
int dimensionLength,
bool descending) { bool descending) {
try { try {
auto xType = ArrayOptions::dataType(xShapeInfo); auto xType = ArrayOptions::dataType(xShapeInfo);
@ -2909,12 +2937,11 @@ void sortTadByKey(Nd4jPointer *extraPointers,
} }
void sortTadByValue(Nd4jPointer *extraPointers, void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, const Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, const Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, const Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, const Nd4jLong *dyShapeInfo,
int *dimension, int *dimension, int dimensionLength,
int dimensionLength,
bool descending) { bool descending) {
try { try {
auto xType = ArrayOptions::dataType(xShapeInfo); auto xType = ArrayOptions::dataType(xShapeInfo);
@ -3195,8 +3222,8 @@ void dbClose(OpaqueDataBuffer *dataBuffer) {
dataBuffer->getDataBuffer()->close(); dataBuffer->getDataBuffer()->close();
} }
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong const*, void*, Nd4jLong const*, const int, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong const* , Nd4jPointer*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong* const*, void**, Nd4jLong* const*, int, int*, Nd4jLong* const*, Nd4jLong* const*), LIBND4J_TYPES);

View File

@ -87,12 +87,12 @@ extern "C" __global__ void prepareShapeBuffer(int *dimension, int *maxDimension,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc, void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -128,12 +128,12 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc, void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -164,12 +164,12 @@ void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc, void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void * hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void * dZ, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -200,11 +200,11 @@ void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -226,16 +226,16 @@ void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -300,16 +300,16 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opN
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -338,15 +338,15 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -413,15 +413,15 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext* lc, const int opNu
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo); auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@ -465,15 +465,15 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
*/ */
void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -536,15 +536,15 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, const int opNum,
void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -572,13 +572,13 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -607,13 +607,13 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension,int dimensionLength, int *dimension,int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -643,13 +643,13 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -689,13 +689,13 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
*/ */
void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -734,13 +734,13 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
*/ */
void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension,int dimensionLength, int *dimension,int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -774,11 +774,11 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo){ void *dZ, Nd4jLong const* dZShapeInfo){
if (sd::Environment::getInstance()->isDebug()) if (sd::Environment::getInstance()->isDebug())
printf("F1 opNum:[%i]\n", opNum); printf("F1 opNum:[%i]\n", opNum);
@ -825,11 +825,11 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -854,11 +854,11 @@ void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -885,11 +885,11 @@ void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -916,11 +916,11 @@ void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -947,12 +947,12 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -981,12 +981,12 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1015,12 +1015,12 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool allowParallelism) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1050,12 +1050,12 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1084,12 +1084,12 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc, void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -1118,11 +1118,11 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1147,13 +1147,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
bool biasCorrected) { bool biasCorrected) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -1178,13 +1178,13 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer(); auto reductionPointer = lc->getReductionPointer();
@ -1215,16 +1215,16 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong* tadOnlyShapeInfo, Nd4jLong* tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong* yTadOnlyShapeInfo, Nd4jLong* yTadOffsets) { Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) {
if(shape::isScalar(hZShapeInfo)) { if(shape::isScalar(hZShapeInfo)) {
NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo);
@ -1268,13 +1268,13 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong const* dZShapeInfo) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1308,12 +1308,12 @@ void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo, void const* hScalar, Nd4jLong const* hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo, void const* dScalar, Nd4jLong const* dScalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1344,16 +1344,16 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1383,12 +1383,12 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo, void const* hScalar, Nd4jLong const* hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo, void const* dScalar, Nd4jLong const* dScalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1419,16 +1419,16 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1458,12 +1458,12 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo,
void *hScalar, Nd4jLong *hScalarShapeInfo, void const* hScalar, Nd4jLong const* hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo, void const* dScalar, Nd4jLong const* dScalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1493,16 +1493,16 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo, void const* hScalars, Nd4jLong const* hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo, void const* dScalars, Nd4jLong const* dScalarShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1531,8 +1531,8 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer stateHost, Nd4jPointer stateHost,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraArguments) { void *extraArguments) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1564,10 +1564,10 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer stateHost, Nd4jPointer stateHost,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraArguments) { void *extraArguments) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1599,12 +1599,12 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
int opNum, int opNum,
Nd4jPointer stateHost, Nd4jPointer stateHost,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
void *extraArguments) { void *extraArguments) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
@ -1634,16 +1634,16 @@ void NativeOpExecutioner::execRandom(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
auto allocationPointer = lc->getAllocationPointer(); auto allocationPointer = lc->getAllocationPointer();
@ -1676,16 +1676,16 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc, void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void const* hX, Nd4jLong const* hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
void *hY, Nd4jLong *hYShapeInfo, void const* hY, Nd4jLong const* hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, void const* dY, Nd4jLong const* dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong const* dZShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) { Nd4jLong const* yTadShapeInfo, Nd4jLong const* yTadOffsets) {
if(shape::isScalar(hZShapeInfo)) { if(shape::isScalar(hZShapeInfo)) {
NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo);

View File

@ -123,8 +123,8 @@ int getDeviceSharedThreshold(int deviceId) {
sd::buffer::Buffer<Nd4jLong> * createScalarBuffer(cudaStream_t stream) { sd::buffer::Buffer<Nd4jLong> * createScalarBuffer(cudaStream_t stream) {
Nd4jLong *scalarShapeInfo = shape::createScalarShapeInfo(); auto scalarShapeInfo = shape::createScalarShapeInfo();
sd::buffer::Buffer<Nd4jLong> *buff = sd::buffer::createBuffer(scalarShapeInfo,shape::shapeInfoLength(2), stream); auto buff = sd::buffer::createBuffer(scalarShapeInfo,shape::shapeInfoLength(2), stream);
sd::buffer::copyDataToGpu(&buff, stream); sd::buffer::copyDataToGpu(&buff, stream);
return buff; return buff;
} }
@ -229,9 +229,9 @@ public:
void execPairwiseTransform( Nd4jPointer *extraPointers, void execPairwiseTransform( Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -251,9 +251,9 @@ void execPairwiseTransform( Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execPairwiseTransformBool(Nd4jPointer *extraPointers, void execPairwiseTransformBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -275,9 +275,9 @@ void execPairwiseTransformBool(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execSummaryStatsScalar(Nd4jPointer *extraPointers, void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -299,11 +299,11 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execBroadcastBool(Nd4jPointer *extraPointers, void execBroadcastBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -348,10 +348,10 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void execBroadcast( void execBroadcast(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -399,9 +399,9 @@ void execBroadcast(
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceFloat(Nd4jPointer *extraPointers, void execReduceFloat(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -421,9 +421,9 @@ void execReduceFloat(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceSame(Nd4jPointer *extraPointers, void execReduceSame(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -443,10 +443,10 @@ void execReduceSame(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceSame2(Nd4jPointer *extraPointers, void execReduceSame2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -476,10 +476,10 @@ void execReduceSame2(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceLong2(Nd4jPointer *extraPointers, void execReduceLong2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -509,9 +509,9 @@ void execReduceLong2(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceLong(Nd4jPointer *extraPointers, void execReduceLong(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -551,10 +551,10 @@ void execReduceLong(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceBool2(Nd4jPointer *extraPointers, void execReduceBool2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -584,9 +584,9 @@ void execReduceBool2(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceBool(Nd4jPointer *extraPointers, void execReduceBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -637,10 +637,10 @@ void execReduceBool(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execIndexReduce(Nd4jPointer *extraPointers, void execIndexReduce(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -679,10 +679,10 @@ void execIndexReduce(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduceFloat2(Nd4jPointer *extraPointers, void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -720,9 +720,9 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
void execIndexReduceScalar( void execIndexReduceScalar(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo){ OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo){
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -741,8 +741,8 @@ void execIndexReduceScalar(
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execTransformSame(Nd4jPointer *extraPointers,int opNum, void execTransformSame(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -766,8 +766,8 @@ void execTransformSame(Nd4jPointer *extraPointers,int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execTransformBool(Nd4jPointer *extraPointers,int opNum, void execTransformBool(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -791,8 +791,8 @@ void execTransformBool(Nd4jPointer *extraPointers,int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execTransformAny(Nd4jPointer *extraPointers,int opNum, void execTransformAny(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -817,8 +817,8 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execTransformStrict(Nd4jPointer *extraPointers,int opNum, void execTransformStrict(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -842,8 +842,8 @@ void execTransformStrict(Nd4jPointer *extraPointers,int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execTransformFloat(Nd4jPointer *extraPointers,int opNum, void execTransformFloat(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1368,7 +1368,7 @@ void specialConcat(
Nd4jPointer *data, Nd4jPointer *data,
Nd4jPointer *inputShapeInfo, Nd4jPointer *inputShapeInfo,
void *dZ, void *dZ,
Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { Nd4jLong const* dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {
try { try {
BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), sd::SpecialMethods, BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), sd::SpecialMethods,
::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), ::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo),
@ -1383,7 +1383,7 @@ void specialConcat(
/** /**
* This method saves * This method saves
*/ */
sd::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) { sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* dXShapeInfo, int *dimension, int dimensionLength) {
try { try {
auto pack = new TadPack(); auto pack = new TadPack();
*pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength); *pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
@ -1395,16 +1395,16 @@ sd::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensi
} }
} }
Nd4jLong* getPrimaryShapeInfo(sd::TadPack* pack) { Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) {
return pack->primaryShapeInfo(); return pack->primaryShapeInfo();
} }
Nd4jLong* getPrimaryOffsets(sd::TadPack* pack) { Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) {
return pack->primaryOffsets(); return pack->primaryOffsets();
} }
Nd4jLong* getSpecialShapeInfo(sd::TadPack* pack) { Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) {
return pack->specialShapeInfo(); return pack->specialShapeInfo();
} }
Nd4jLong* getSpecialOffsets(sd::TadPack* pack) { Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) {
return pack->specialOffsets(); return pack->specialOffsets();
} }
Nd4jLong getNumberOfTads(sd::TadPack* pack) { Nd4jLong getNumberOfTads(sd::TadPack* pack) {
@ -1460,14 +1460,14 @@ Nd4jPointer getConstantSpace() {
} }
void pullRows(Nd4jPointer *extraPointers, void pullRows(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dZShapeInfo,
Nd4jLong n, Nd4jLong n,
Nd4jLong *indexes, Nd4jLong *indexes,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong const* tadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong const* zTadShapeInfo,
Nd4jLong *zTadOffsets) { Nd4jLong const* zTadOffsets) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1489,10 +1489,10 @@ void pullRows(Nd4jPointer *extraPointers,
void average(Nd4jPointer *extras, void average(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo, Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dXShapeInfo, Nd4jPointer *dx, Nd4jLong const* dXShapeInfo,
void *z, Nd4jLong *zShapeInfo, void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo, void *dz, Nd4jLong const* dzShapeInfo,
int n, int n,
Nd4jLong length, Nd4jLong length,
bool propagate) { bool propagate) {
@ -1524,10 +1524,10 @@ void average(Nd4jPointer *extras,
} }
void accumulate(Nd4jPointer *extras, void accumulate(Nd4jPointer *extras,
Nd4jPointer *x, Nd4jLong *xShapeInfo, Nd4jPointer *x, Nd4jLong const* xShapeInfo,
Nd4jPointer *dx, Nd4jLong *dXShapeInfo, Nd4jPointer *dx, Nd4jLong const* dXShapeInfo,
void *z, Nd4jLong *zShapeInfo, void *z, Nd4jLong const* zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo, void *dz, Nd4jLong const* dzShapeInfo,
int n, int n,
Nd4jLong length) { Nd4jLong length) {
try { try {
@ -1572,8 +1572,8 @@ void shuffle(Nd4jPointer *extras,
auto dX = reinterpret_cast<void **>(dx); auto dX = reinterpret_cast<void **>(dx);
auto dZ = reinterpret_cast<void **>(dz); auto dZ = reinterpret_cast<void **>(dz);
auto xShape = reinterpret_cast<Nd4jLong **>(xShapeInfo); auto xShape = reinterpret_cast<Nd4jLong**>(xShapeInfo);
auto dxShape = reinterpret_cast<Nd4jLong **>(dXShapeInfo); auto dxShape = reinterpret_cast<Nd4jLong**>(dXShapeInfo);
auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong **>(tadShapeInfo); auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong **>(tadShapeInfo);
auto tadOffset = reinterpret_cast<Nd4jLong **>(tadOffsets); auto tadOffset = reinterpret_cast<Nd4jLong **>(tadOffsets);
@ -1614,9 +1614,9 @@ void setTADThreshold(int num) {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execSummaryStats(Nd4jPointer *extraPointers, void execSummaryStats(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1638,12 +1638,12 @@ void execSummaryStats(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execSummaryStatsTad(Nd4jPointer *extraPointers, void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
bool biasCorrected, bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1670,10 +1670,10 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduce3(Nd4jPointer *extraPointers, void execReduce3(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -1694,13 +1694,13 @@ void execReduce3(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduce3Tad(Nd4jPointer *extraPointers, void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1744,10 +1744,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -1768,9 +1768,9 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execScalarBool(Nd4jPointer *extraPointers, void execScalarBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar});
@ -1792,13 +1792,13 @@ void execScalarBool(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execScalarBoolTad(Nd4jPointer *extraPointers, void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1825,9 +1825,9 @@ void execScalarBoolTad(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execScalar(Nd4jPointer *extraPointers, void execScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams) { void *extraParams) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar});
@ -1849,13 +1849,13 @@ void execScalar(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execScalarTad(Nd4jPointer *extraPointers, void execScalarTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
void *extraParams, void *extraParams,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -1931,7 +1931,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
void execRandom(Nd4jPointer *extraPointers, void execRandom(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer stateHost, Nd4jPointer stateHost,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraArguments) { void *extraArguments) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {}); InteropDataBuffer::prepareSpecialUse({dbZ}, {});
@ -1950,8 +1950,8 @@ void execRandom(Nd4jPointer *extraPointers,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraArguments) { void *extraArguments) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX});
@ -1971,9 +1971,9 @@ void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
void *extraArguments) { void *extraArguments) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
@ -2091,11 +2091,11 @@ Nd4jPointer pointerForAddress(Nd4jLong address) {
} }
void tear(Nd4jPointer *extras, void tear(Nd4jPointer *extras,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo,
Nd4jPointer *targets, Nd4jPointer *targets,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets) { Nd4jLong const* tadOffsets) {
try { try {
InteropDataBuffer::prepareSpecialUse({}, {dbX}); InteropDataBuffer::prepareSpecialUse({}, {dbX});
@ -2200,13 +2200,13 @@ void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, int numElement
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execReduce3All(Nd4jPointer *extraPointers, void execReduce3All(Nd4jPointer *extraPointers,
int opNum, int opNum,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) {
try { try {
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension});
InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); InteropDataBuffer::preparePrimaryUse({}, {dbDimension});
@ -2232,8 +2232,8 @@ void execReduce3All(Nd4jPointer *extraPointers,
void sort(Nd4jPointer *extraPointers, void sort(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong const* dXShapeInfo,
bool descending) { bool descending) {
try { try {
cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -2298,10 +2298,10 @@ void sort(Nd4jPointer *extraPointers,
void sortByKey(Nd4jPointer *extraPointers, void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
bool descending) { bool descending) {
try { try {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -2372,10 +2372,10 @@ void sortByKey(Nd4jPointer *extraPointers,
} }
void sortByValue(Nd4jPointer *extraPointers, void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
bool descending) { bool descending) {
try { try {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -2447,10 +2447,10 @@ void sortByValue(Nd4jPointer *extraPointers,
void sortTadByKey(Nd4jPointer *extraPointers, void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
bool descending) { bool descending) {
@ -2474,10 +2474,10 @@ void sortTadByKey(Nd4jPointer *extraPointers,
} }
void sortTadByValue(Nd4jPointer *extraPointers, void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong const* dXShapeInfo,
void *y, Nd4jLong *yShapeInfo, void *y, Nd4jLong const* yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo, void *dy, Nd4jLong const* dyShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
bool descending) { bool descending) {
@ -2503,12 +2503,12 @@ void sortTadByValue(Nd4jPointer *extraPointers,
void sortTad(Nd4jPointer *extraPointers, void sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong const* xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong const* dXShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong const* tadOffsets,
bool descending) { bool descending) {
try { try {
// to be implemented // to be implemented
@ -2653,7 +2653,7 @@ Nd4jLong getShapeListSize(sd::ShapeList* list) {
return list->size(); return list->size();
} }
Nd4jLong* getShape(sd::ShapeList* list, Nd4jLong i) { Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) {
return list->at(i); return list->at(i);
} }
@ -2877,7 +2877,7 @@ const char* getVariableName(sd::graph::Variable* variable) {
return variable->getName()->c_str(); return variable->getName()->c_str();
} }
Nd4jLong* getVariableShape(sd::graph::Variable* variable) { Nd4jLong const* getVariableShape(sd::graph::Variable* variable) {
return variable->getNDArray()->shapeInfo(); return variable->getNDArray()->shapeInfo();
} }
@ -3026,7 +3026,7 @@ void deleteResultWrapper(Nd4jPointer ptr) {
delete p; delete p;
} }
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong *dXShapeInfo, int N, float threshold) { int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong const* dXShapeInfo, int N, float threshold) {
throw std::runtime_error("estimateThreshold: Not implemented yet"); throw std::runtime_error("estimateThreshold: Not implemented yet");
} }
@ -3237,7 +3237,7 @@ void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T, typename I> template<typename T, typename I>
__global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArrs, __global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArrs,
void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong *xOffsets,
void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
const void* vindexes) { const void* vindexes) {
@ -3300,7 +3300,7 @@ __global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArr
} }
template<typename T, typename I> template<typename T, typename I>
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) { __host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong const* xShapeInfo, const Nd4jLong* xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) {
scatterUpdateCuda<T, I><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); scatterUpdateCuda<T, I><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes);
} }
@ -3308,11 +3308,11 @@ __host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets,
void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets,
void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets,
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets,
void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo) {
try { try {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
@ -3409,7 +3409,7 @@ bool isBlasVersionMatches(int major, int minor, int build) {
return result; return result;
} }
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length) { sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length) {
return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
} }
@ -3555,8 +3555,7 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
} else { } else {
shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
} }
return reinterpret_cast<Nd4jPointer>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, return (Nd4jPointer)(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); // TO DO: this can lead to unpleasant crash sometimes
true));
} catch (std::exception &e) { } catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

View File

@ -21,6 +21,7 @@
#define DEV_TESTS_BROADCASTSCALARCONVERTER_H #define DEV_TESTS_BROADCASTSCALARCONVERTER_H
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>
#include <system/op_enums.h>
#include <stdexcept> #include <stdexcept>
namespace sd { namespace sd {

View File

@ -56,18 +56,15 @@ namespace functions {
class Broadcast { class Broadcast {
public: public:
#ifdef __CUDACC__ #ifdef __CUDABLAS__
template<typename OpType> template<typename OpType>
static __device__ void transformCuda( static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *y, int *dimension, int dimensionLength,
Nd4jLong *yShapeInfo, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
void *result, const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType> template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
@ -75,67 +72,83 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo); static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo); static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType> template<typename OpType>
static __device__ void transformInverseCuda( static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *y, int *dimension, int dimensionLength,
Nd4jLong *yShapeInfo, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
void *result, const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else #else
static void execInverse(int opNum, static void execInverse(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(int opNum, static void exec(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
sd::LoopKind::Kind loopKind, sd::LoopKind::Kind loopKind,
uint64_t start, uint64_t start, uint64_t stop);
uint64_t stop);
/** /**
* CPU execution * CPU execution
@ -149,39 +162,25 @@ namespace functions {
* @param dimensionLength the length of the dimension buffer * @param dimensionLength the length of the dimension buffer
*/ */
template<typename OpType> template<typename OpType>
static void exec(void *x, static void exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
sd::LoopKind::Kind loopKind, sd::LoopKind::Kind loopKind,
uint64_t start, uint64_t start, uint64_t stop);
uint64_t stop);
template<typename OpType> template<typename OpType>
static void execInverse(void *x, static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(const int opNum, static void exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);

View File

@ -58,16 +58,13 @@ namespace functions {
#ifdef __CUDACC__ #ifdef __CUDACC__
template<typename OpType> template<typename OpType>
static __device__ void transformCuda( static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *y, void *extraParams,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType> template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
@ -76,7 +73,7 @@ namespace functions {
void *extraParams); void *extraParams);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
@ -85,7 +82,7 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *extraParams); void *extraParams);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
@ -94,63 +91,61 @@ namespace functions {
void *extraParams); void *extraParams);
template<typename OpType> template<typename OpType>
static __device__ void transformInverseCuda( static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *y, void *extraParams,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else #else
static void exec(int opNum, static void exec(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int *dimension, int dimensionLength,
int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
Nd4jLong *tadOffset, uint64_t start, uint64_t stop);
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(const int opNum, static void exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *extraParams); void *extraParams);
static void execInverse(int opNum, static void execInverse(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, void *extraParams,
void *result, int *dimension, int dimensionLength,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
void *extraParams, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
/** /**
* CPU execution * CPU execution
@ -164,21 +159,14 @@ namespace functions {
* @param dimensionLength the length of the dimension buffer * @param dimensionLength the length of the dimension buffer
*/ */
template<typename OpType> template<typename OpType>
static void exec(void *x, static void exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int *dimension, int dimensionLength,
int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
Nd4jLong *tadOffset, uint64_t start, uint64_t stop);
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
template<typename OpType> template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo, static void exec(const void *x, const Nd4jLong *xShapeInfo,
@ -187,21 +175,14 @@ namespace functions {
void *extraParams); void *extraParams);
template<typename OpType> template<typename OpType>
static void execInverse(void *x, static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, void *extraParams,
void *result, int *dimension, int dimensionLength,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
void *extraParams, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
#endif #endif
}; };
} }

View File

@ -58,15 +58,12 @@ namespace functions {
#ifdef __CUDACC__ #ifdef __CUDACC__
template<typename OpType> template<typename OpType>
static __device__ void transformCuda( static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *y, int *dimension, int dimensionLength,
Nd4jLong *yShapeInfo, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
void *result, const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType> template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
@ -74,7 +71,13 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
@ -82,7 +85,14 @@ namespace functions {
const void *y, const Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
@ -90,59 +100,55 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);
template<typename OpType> template<typename OpType>
static __device__ void transformInverseCuda( static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo,
void *y, int *dimension, int dimensionLength,
Nd4jLong *yShapeInfo, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
void *result, const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *result, const Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else #else
static void exec(int opNum, static void exec(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(const int opNum, static void exec(int opNum,
const void *x, const Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);
static void execInverse(int opNum, static void execInverse(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
/** /**
* CPU execution * CPU execution
@ -156,20 +162,13 @@ namespace functions {
* @param dimensionLength the length of the dimension buffer * @param dimensionLength the length of the dimension buffer
*/ */
template<typename OpType> template<typename OpType>
static void exec(void *x, static void exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
template<typename OpType> template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo, static void exec(const void *x, const Nd4jLong *xShapeInfo,
@ -177,20 +176,13 @@ namespace functions {
void *z, const Nd4jLong *zShapeInfo); void *z, const Nd4jLong *zShapeInfo);
template<typename OpType> template<typename OpType>
static void execInverse(void *x, static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *result, const Nd4jLong *resultShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *result, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
Nd4jLong *resultShapeInfo, const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
int *dimension, uint64_t start, uint64_t stop);
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
#endif #endif
}; };
} }

View File

@ -34,20 +34,13 @@ namespace broadcast {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::execInverse(const int opNum, void Broadcast<X, Y, Z>::execInverse(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *z, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x, DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -64,21 +57,14 @@ namespace broadcast {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::exec(const int opNum, void Broadcast<X, Y, Z>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *z, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, sd::LoopKind::Kind loopKind,
int dimensionLength, uint64_t start, uint64_t stop) {
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
sd::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -96,24 +82,17 @@ namespace broadcast {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void Broadcast<X, Y, Z>::exec(void *vx, void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *vz, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, sd::LoopKind::Kind loopKind,
int dimensionLength, uint64_t start, uint64_t stop) {
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
sd::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after
@ -397,23 +376,16 @@ namespace broadcast {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void Broadcast<X, Y, Z>::execInverse(void *vx, void Broadcast<X, Y, Z>::execInverse(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *vz, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after

View File

@ -33,21 +33,14 @@ namespace broadcast {
template <typename X, typename Y> template <typename X, typename Y>
void BroadcastBool<X, Y>::exec(const int opNum, void BroadcastBool<X, Y>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, void *extraParams,
void *z, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
void *extraParams, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -75,21 +68,14 @@ namespace broadcast {
template <typename X, typename Y> template <typename X, typename Y>
void BroadcastBool<X, Y>::execInverse(const int opNum, void BroadcastBool<X, Y>::execInverse(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, void *extraParams,
void *z, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
void *extraParams, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x, DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -107,24 +93,17 @@ namespace broadcast {
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void BroadcastBool<X, Z>::exec(void *vx, void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, void *vextraParams,
void *vz, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
void *vextraParams, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams); auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -138,8 +117,8 @@ namespace broadcast {
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
xTadShapeShapeInfo = tadPack.primaryShapeInfo(); xTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = tadPack.primaryOffsets(); tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
} }
//int *resultStride = shape::stride(xTadShapeShapeInfo); //int *resultStride = shape::stride(xTadShapeShapeInfo);
@ -279,24 +258,17 @@ namespace broadcast {
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void BroadcastBool<X, Z>::execInverse(void *vx, void BroadcastBool<X, Z>::execInverse(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, void *vextraParams,
void *vz, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset,
void *vextraParams, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams); auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -310,8 +282,8 @@ namespace broadcast {
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
yTadShapeShapeInfo = tadPack.primaryShapeInfo(); yTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = tadPack.primaryOffsets(); tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
} }
//int *resultStride = shape::stride(yTadShapeShapeInfo); //int *resultStride = shape::stride(yTadShapeShapeInfo);

View File

@ -33,20 +33,13 @@ namespace functions {
template <typename X> template <typename X>
void BroadcastInt<X>::exec(const int opNum, void BroadcastInt<X>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *z, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -72,20 +65,13 @@ namespace functions {
template <typename X> template <typename X>
void BroadcastInt<X>::execInverse(const int opNum, void BroadcastInt<X>::execInverse(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo,
void *y, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *z, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -102,23 +88,16 @@ namespace functions {
template <typename X> template <typename X>
template<typename OpType> template<typename OpType>
void BroadcastInt<X>::exec(void *vx, void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, int dimensionLength,
void *vz, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after
@ -131,8 +110,8 @@ namespace functions {
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
xTadShapeShapeInfo = tadPack.primaryShapeInfo(); xTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = tadPack.primaryOffsets(); tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
} }
//int *resultStride = shape::stride(xTadShapeShapeInfo); //int *resultStride = shape::stride(xTadShapeShapeInfo);
@ -272,23 +251,16 @@ namespace functions {
template <typename X> template <typename X>
template<typename OpType> template<typename OpType>
void BroadcastInt<X>::execInverse(void *vx, void BroadcastInt<X>::execInverse(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, int *dimension, const int dimensionLength,
void *vz, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset,
Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset,
int *dimension, uint64_t start, uint64_t stop) {
int dimensionLength,
Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after
@ -301,8 +273,8 @@ namespace functions {
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
yTadShapeShapeInfo = tadPack.primaryShapeInfo(); yTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
tadOffsets = tadPack.primaryOffsets(); tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
} }
//int *resultStride = shape::stride(yTadShapeShapeInfo); //int *resultStride = shape::stride(yTadShapeShapeInfo);

View File

@ -33,27 +33,27 @@ namespace indexreduce {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
Nd4jLong IndexReduce<X,Y>::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { Nd4jLong IndexReduce<X,Y>::execScalar( const int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
void IndexReduce<X,Y>::exec(const int opNum, void IndexReduce<X,Y>::exec(const int opNum,
void *x, Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
void *extraParams, void *extraParams,
void *z, Nd4jLong *zShapeInfo, void *z, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
template<typename OpType> template<typename OpType>
Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { Nd4jLong IndexReduce<X, Y>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
//T startingVal = OpType::startingValue(x); //T startingVal = OpType::startingValue(x);
@ -107,13 +107,13 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo, void IndexReduce<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -136,7 +136,7 @@ void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
} }
auto tadOnlyShapeInfo = tadShapeInfo; auto tadOnlyShapeInfo = tadShapeInfo;
Nd4jLong *tadOffsets = tadOffset; auto tadOffsets = tadOffset;
if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) {
if (dimensionLength < 1) if (dimensionLength < 1)

View File

@ -34,18 +34,13 @@ namespace functions {
namespace pairwise_transforms { namespace pairwise_transforms {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec( void PairWiseTransform<X, Y, Z>::exec(const int opNum,
const int opNum, const void *x, Nd4jLong xEws,
void *x, const void *y, Nd4jLong yEws,
Nd4jLong xEws, void *z, Nd4jLong zEws,
void *y, void *extraParams,
Nd4jLong yEws, Nd4jLong n,
void *z, const uint64_t start,const uint64_t stop) {
Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xEws, xEws,
y, y,
@ -60,16 +55,16 @@ namespace functions {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
template <typename OpType> template <typename OpType>
void PairWiseTransform<X, Y, Z>::exec(void *vx, Nd4jLong xEws, void PairWiseTransform<X, Y, Z>::exec(const void *vx, Nd4jLong xEws,
void *vy, Nd4jLong yEws, const void *vy, Nd4jLong yEws,
void *vz, Nd4jLong zEws, void *vz, Nd4jLong zEws,
void *vextraParams, void *vextraParams,
const Nd4jLong n, const Nd4jLong n,
const uint64_t start, const uint64_t start,
const uint64_t stop) { const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -86,17 +81,12 @@ namespace functions {
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec( void PairWiseTransform<X, Y, Z>::exec(const int opNum,
const int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *y, void *extraParams,
Nd4jLong *yShapeInfo, const uint64_t start, const uint64_t stop) {
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -110,19 +100,14 @@ namespace functions {
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
template <typename OpType> template <typename OpType>
void PairWiseTransform<X, Y, Z>::exec( void PairWiseTransform<X, Y, Z>::exec(const void *vx, const Nd4jLong* xShapeInfo,
void *vx, const void *vy, const Nd4jLong* yShapeInfo,
Nd4jLong* xShapeInfo, void *vz, const Nd4jLong* zShapeInfo,
void *vy, void *vextraParams,
Nd4jLong* yShapeInfo, const uint64_t start, const uint64_t stop) {
void *vz,
Nd4jLong* zShapeInfo,
void *vextraParams,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<const Y *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);

View File

@ -30,18 +30,13 @@ namespace functions {
namespace pairwise_transforms { namespace pairwise_transforms {
template <typename X, typename Y> template <typename X, typename Y>
void PairWiseBoolTransform<X, Y>::exec( void PairWiseBoolTransform<X, Y>::exec(const int opNum,
const int opNum, const void *x, Nd4jLong xEws,
void *x, const void *y, Nd4jLong yEws,
Nd4jLong xEws, void *z, Nd4jLong zEws,
void *y, void *extraParams,
Nd4jLong yEws, Nd4jLong n,
void *z, const uint64_t start, const uint64_t stop) {
Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xEws, xEws,
y, y,
@ -56,19 +51,15 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void PairWiseBoolTransform<X, Z>::exec(void *vx, void PairWiseBoolTransform<X, Z>::exec(const void *vx, Nd4jLong xEws,
Nd4jLong xEws, const void *vy, Nd4jLong yEws,
void *vy, void *vz, Nd4jLong zEws,
Nd4jLong yEws, void *vextraParams,
void *vz, const Nd4jLong n,
Nd4jLong zEws, const uint64_t start, const uint64_t stop) {
void *vextraParams,
const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -85,17 +76,12 @@ namespace functions {
} }
template <typename X, typename Y> template <typename X, typename Y>
void PairWiseBoolTransform<X, Y>::exec( void PairWiseBoolTransform<X, Y>::exec(const int opNum,
const int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *y, void *extraParams,
Nd4jLong *yShapeInfo, const uint64_t start,const uint64_t stop) {
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -109,15 +95,14 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void PairWiseBoolTransform<X, Z>::exec(void *vx, Nd4jLong* xShapeInfo, void PairWiseBoolTransform<X, Z>::exec(const void *vx, const Nd4jLong* xShapeInfo,
void *vy, Nd4jLong* yShapeInfo, const void *vy, const Nd4jLong* yShapeInfo,
void *vz, Nd4jLong* zShapeInfo, void *vz, const Nd4jLong* zShapeInfo,
void *vextraParams, void *vextraParams,
const uint64_t start, const uint64_t start, const uint64_t stop) {
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -30,18 +30,13 @@ namespace functions {
namespace pairwise_transforms { namespace pairwise_transforms {
template <typename X> template <typename X>
void PairWiseIntTransform<X>::exec( void PairWiseIntTransform<X>::exec(const int opNum,
const int opNum, const void *x, Nd4jLong xEws,
void *x, const void *y, Nd4jLong yEws,
Nd4jLong xEws, void *z, Nd4jLong zEws,
void *y, void *extraParams,
Nd4jLong yEws, Nd4jLong n,
void *z, const uint64_t start, const uint64_t stop) {
Nd4jLong zEws,
void *extraParams,
Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xEws, xEws,
y, y,
@ -56,19 +51,15 @@ namespace functions {
template <typename X> template <typename X>
template <typename OpType> template <typename OpType>
void PairWiseIntTransform<X>::exec(void *vx, void PairWiseIntTransform<X>::exec(const void *vx, Nd4jLong xEws,
Nd4jLong xEws, const void *vy, Nd4jLong yEws,
void *vy, void *vz, Nd4jLong zEws,
Nd4jLong yEws, void *vextraParams,
void *vz, const Nd4jLong n,
Nd4jLong zEws, const uint64_t start,
void *vextraParams, const uint64_t stop) {
const Nd4jLong n, auto x = reinterpret_cast<const X *>(vx);
const uint64_t start, auto y = reinterpret_cast<const X *>(vy);
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -85,17 +76,12 @@ namespace functions {
} }
template <typename X> template <typename X>
void PairWiseIntTransform<X>::exec( void PairWiseIntTransform<X>::exec(const int opNum,
const int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, const void *y, const Nd4jLong *yShapeInfo,
Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *y, void *extraParams,
Nd4jLong *yShapeInfo, const uint64_t start, const uint64_t stop) {
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -109,15 +95,15 @@ namespace functions {
template <typename X> template <typename X>
template <typename OpType> template <typename OpType>
void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo, void PairWiseIntTransform<X>::exec(const void *vx, const Nd4jLong* xShapeInfo,
void *vy, Nd4jLong* yShapeInfo, const void *vy, const Nd4jLong* yShapeInfo,
void *vz, Nd4jLong* zShapeInfo, void *vz, const Nd4jLong* zShapeInfo,
void *vextraParams, void *vextraParams,
const uint64_t start, const uint64_t start,
const uint64_t stop) { const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -33,16 +33,13 @@ namespace functions {
template<typename X> template<typename X>
template<typename OpClass> template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state, void RandomFunction<X>::execTransform(Nd4jPointer state,
void *vx, const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vy, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *yShapeInfo, void *vextraArguments) {
void *vz,
Nd4jLong *zShapeInfo,
void *vextraArguments) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraArguments = reinterpret_cast<X *>(vextraArguments); auto extraArguments = reinterpret_cast<X *>(vextraArguments);
@ -166,12 +163,10 @@ namespace functions {
template<typename X> template<typename X>
template<typename OpClass> template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state, void RandomFunction<X>::execTransform(Nd4jPointer state,
void *vx, const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vz, void *vextraArguments) {
Nd4jLong *zShapeInfo, auto x = reinterpret_cast<const X *>(vx);
void *vextraArguments) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraArguments = reinterpret_cast<X *>(vextraArguments); auto extraArguments = reinterpret_cast<X *>(vextraArguments);
@ -227,7 +222,7 @@ namespace functions {
template<typename X> template<typename X>
template<typename OpClass> template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state, void *vz, Nd4jLong *zShapeInfo, void *vextraArguments) { void RandomFunction<X>::execTransform(Nd4jPointer state, void *vz, const Nd4jLong *zShapeInfo, void *vextraArguments) {
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraArguments = reinterpret_cast<X *>(vextraArguments); auto extraArguments = reinterpret_cast<X *>(vextraArguments);
@ -266,17 +261,17 @@ namespace functions {
} }
template<typename X> template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraArguments) { void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS)
} }
template<typename X> template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraArguments) { void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS)
} }
template<typename X> template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeInfo, void *extraArguments) { void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
} }

View File

@ -33,12 +33,10 @@ namespace functions {
namespace reduce { namespace reduce {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceBoolFunction<X,Z>::execScalar(void *vx, void _CUDA_H ReduceBoolFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo) {
void *vz, auto x = reinterpret_cast<const X *>(vx);
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -78,9 +76,9 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
@ -103,49 +101,39 @@ namespace functions {
template <typename X, typename Y> template <typename X, typename Y>
Y ReduceBoolFunction<X, Y>::execScalar(const int opNum, Y ReduceBoolFunction<X, Y>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams) {
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_BOOL_OPS); RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_BOOL_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void ReduceBoolFunction<X, Y>::execScalar(const int opNum, void ReduceBoolFunction<X, Y>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo) {
void *z,
Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS); DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void ReduceBoolFunction<X, Y>::exec(const int opNum, void ReduceBoolFunction<X, Y>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceBoolFunction<X,Z>::exec(void *vx, void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vresult, const Nd4jLong *zShapeInfo,
void *vresult, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, int64_t stop) {
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vresult); auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -193,20 +181,17 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H ReduceBoolFunction<X,Z>::exec(void *x, void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *vresult, const Nd4jLong *resultShapeInfo) {
void *vresult,
Nd4jLong *resultShapeInfo) {
// FIXME: wtf???
auto z = reinterpret_cast<Z*>(vresult); auto z = reinterpret_cast<Z*>(vresult);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams); z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];

View File

@ -33,12 +33,10 @@ namespace functions {
namespace reduce { namespace reduce {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceFloatFunction<X,Z>::execScalar(void *vx, void _CUDA_H ReduceFloatFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo) {
void *vz, auto x = reinterpret_cast<const X *>(vx);
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -98,8 +96,8 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
@ -122,33 +120,27 @@ namespace functions {
template <typename X, typename Y> template <typename X, typename Y>
Y ReduceFloatFunction<X, Y>::execScalar(const int opNum, Y ReduceFloatFunction<X, Y>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams) {
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_FLOAT_OPS); RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_FLOAT_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void ReduceFloatFunction<X, Y>::execScalar(const int opNum, void ReduceFloatFunction<X, Y>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo) {
void *z,
Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void ReduceFloatFunction<X, Y>::exec(const int opNum, void ReduceFloatFunction<X, Y>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
extraParams, extraParams,
@ -163,17 +155,14 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceFloatFunction<X,Z>::exec(void *vx, void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vresult, const Nd4jLong *zShapeInfo,
void *vresult, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vresult); auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -226,11 +215,9 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H ReduceFloatFunction<X,Z>::exec(void *x, void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *vresult, const Nd4jLong *resultShapeInfo) {
void *vresult,
Nd4jLong *resultShapeInfo) {
// FIXME: wtf??? // FIXME: wtf???
auto z = reinterpret_cast<Z*>(vresult); auto z = reinterpret_cast<Z*>(vresult);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams); z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
@ -238,9 +225,9 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];

View File

@ -33,12 +33,10 @@ namespace functions {
namespace reduce { namespace reduce {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceLongFunction<X,Z>::execScalar(void *vx, void _CUDA_H ReduceLongFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo) {
void *vz, auto x = reinterpret_cast<const X *>(vx);
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -93,10 +91,8 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
Nd4jLong *xShapeInfo, auto x = reinterpret_cast<const X *>(vx);
void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
@ -120,49 +116,40 @@ namespace functions {
template <typename X, typename Y> template <typename X, typename Y>
Y ReduceLongFunction<X, Y>::execScalar(const int opNum, Y ReduceLongFunction<X, Y>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams) {
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_LONG_OPS); RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_LONG_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void ReduceLongFunction<X, Y>::execScalar(const int opNum, void ReduceLongFunction<X, Y>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo) {
void *z,
Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_LONG_OPS); DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_LONG_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void ReduceLongFunction<X, Y>::exec(const int opNum, void ReduceLongFunction<X, Y>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceLongFunction<X,Z>::exec(void *vx, void _CUDA_H ReduceLongFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vresult, const Nd4jLong *zShapeInfo,
void *vresult, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vresult); auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -215,21 +202,18 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H ReduceLongFunction<X,Z>::exec(void *x, void _CUDA_H ReduceLongFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *vresult, const Nd4jLong *resultShapeInfo) {
void *vresult,
Nd4jLong *resultShapeInfo) {
// FIXME: wtf???
auto z = reinterpret_cast<Z*>(vresult); auto z = reinterpret_cast<Z*>(vresult);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams); z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];

View File

@ -34,12 +34,10 @@ namespace functions {
namespace reduce { namespace reduce {
template <typename X> template <typename X>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, void _CUDA_H ReduceSameFunction<X>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo) {
void *vz, auto x = reinterpret_cast<const X *>(vx);
Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -95,10 +93,8 @@ namespace functions {
template <typename X> template <typename X>
template <typename OpType> template <typename OpType>
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, X _CUDA_H ReduceSameFunction<X>::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
Nd4jLong *xShapeInfo, auto x = reinterpret_cast<const X *>(vx);
void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
@ -120,33 +116,27 @@ namespace functions {
template <typename X> template <typename X>
X ReduceSameFunction<X>::execScalar(const int opNum, X ReduceSameFunction<X>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams) {
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_SAME_OPS); RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_SAME_OPS);
} }
template <typename X> template <typename X>
void ReduceSameFunction<X>::execScalar(const int opNum, void ReduceSameFunction<X>::execScalar(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo) {
void *z,
Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_SAME_OPS); DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_SAME_OPS);
} }
template <typename X> template <typename X>
void ReduceSameFunction<X>::exec(const int opNum, void ReduceSameFunction<X>::exec(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
extraParams, extraParams,
@ -161,17 +151,14 @@ namespace functions {
template <typename X> template <typename X>
template <typename OpType> template <typename OpType>
void _CUDA_H ReduceSameFunction<X>::exec(void *vx, void _CUDA_H ReduceSameFunction<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo,
void *vz, int *dimension, int dimensionLength,
Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
int *dimension, int64_t start, int64_t stop) {
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -224,21 +211,18 @@ namespace functions {
template <typename X> template <typename X>
template<typename OpType> template<typename OpType>
void _CUDA_H ReduceSameFunction<X>::exec(void *x, void _CUDA_H ReduceSameFunction<X>::exec(const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *vz, const Nd4jLong *zShapeInfo) {
void *vz,
Nd4jLong *zShapeInfo) {
// FIXME: wtf???
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams); z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
} }
template <typename X> template <typename X>
template <typename OpType> template <typename OpType>
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { X _CUDA_H ReduceSameFunction<X>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
X intermediate[64]; X intermediate[64];

View File

@ -34,13 +34,13 @@ namespace reduce3 {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo, void Reduce3<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo) { void *vz, const Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
@ -134,10 +134,10 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
void Reduce3<X,Y>::execScalar(const int opNum, void Reduce3<X,Y>::execScalar(const int opNum,
void *vx, Nd4jLong *xShapeInfo, const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo) { void *vz, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS);
} }
@ -146,14 +146,15 @@ void Reduce3<X,Y>::execScalar(const int opNum,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo, void Reduce3<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int64_t start, int64_t stop) { int *dimension, int dimensionLength,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<const X*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<const X*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<Z*>(vextraParams); auto extraParams = reinterpret_cast<Z*>(vextraParams);
@ -171,15 +172,16 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo, void Reduce3<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
@ -193,16 +195,17 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo, void Reduce3<X,Z>:: execAll(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) { const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets,
int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<const X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z*>(vextraParams); auto extraParams = reinterpret_cast<Z*>(vextraParams);
@ -215,12 +218,13 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
void Reduce3<X,Y>::exec( const int opNum, void Reduce3<X,Y>::exec(const int opNum,
void *vx, Nd4jLong *xShapeInfo, const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int64_t start, int64_t stop) { int *dimension, int dimensionLength,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS);
} }
@ -228,13 +232,14 @@ void Reduce3<X,Y>::exec( const int opNum,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
void Reduce3<X,Y>::exec( const int opNum, void Reduce3<X,Y>::exec(const int opNum,
void *vx, Nd4jLong *xShapeInfo, const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) { const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS);
} }
@ -243,13 +248,14 @@ void Reduce3<X,Y>::exec( const int opNum,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
void Reduce3<X,Y>::execAll(const int opNum, void Reduce3<X,Y>::execAll(const int opNum,
void *vx, Nd4jLong *xShapeInfo, const void *vx, const Nd4jLong *xShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) { const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets,
int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS);
} }

View File

@ -34,18 +34,18 @@ namespace scalar {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo, void ScalarTransform<X, Y, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vscalars, const void *vscalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalars = reinterpret_cast<Y *>(vscalars); auto scalars = reinterpret_cast<const Y *>(vscalars);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
if (zTadShapeInfo == nullptr) { if (zTadShapeInfo == nullptr) {
@ -92,14 +92,14 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
void ScalarTransform<X,Y,Z>::transform(int opNum, void ScalarTransform<X,Y,Z>::transform(int opNum,
void *x, Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
void *extraParams, void *extraParams,
void *z, Nd4jLong *zShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *scalars, const void *scalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS); DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS);
} }
@ -107,12 +107,12 @@ void ScalarTransform<X,Y,Z>::transform(int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
void ScalarTransform<X, Y, Z>::transform(const int opNum, void ScalarTransform<X, Y, Z>::transform(const int opNum,
void *x, Nd4jLong xStride, const void *x, Nd4jLong xStride,
void *z, Nd4jLong zStride, void *z, Nd4jLong zStride,
void *scalar, const void *scalar,
void *extraParams, void *extraParams,
const uint64_t n, const uint64_t n,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS); DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS);
} }
@ -120,11 +120,11 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
void ScalarTransform<X, Y, Z>::transform(const int opNum, void ScalarTransform<X, Y, Z>::transform(const int opNum,
void *x, Nd4jLong *xShapeInfo, const void *x, const Nd4jLong *xShapeInfo,
void *z, Nd4jLong *zShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *scalar, const void *scalar,
void *extraParams, void *extraParams,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS); DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS);
} }
@ -132,15 +132,15 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo, void ScalarTransform<X, Y, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vscalar, const void *vscalar,
void *vextraParams, void *vextraParams,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<Y *>(vscalar)[0]; auto scalar = reinterpret_cast<const Y *>(vscalar)[0];
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
const auto len = shape::length(xShapeInfo); const auto len = shape::length(xShapeInfo);
@ -181,15 +181,15 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws, void ScalarTransform<X, Y, Z>::transform(const void *vx, Nd4jLong xEws,
void *vz, Nd4jLong zEws, void *vz, Nd4jLong zEws,
void *vscalar, const void *vscalar,
void *vextraParams, void *vextraParams,
const uint64_t len, const uint64_t start, const uint64_t stop) { const uint64_t len, const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<Y *>(vscalar)[0]; auto scalar = reinterpret_cast<const Y *>(vscalar)[0];
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
if (xEws == 1 && zEws == 1) { if (xEws == 1 && zEws == 1) {

View File

@ -34,18 +34,18 @@ namespace functions {
template<typename X, typename Z> template<typename X, typename Z>
template<typename OpType> template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx, Nd4jLong *xShapeInfo, void ScalarBoolTransform<X, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vscalars, const void *vscalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalars = reinterpret_cast<X *>(vscalars); auto scalars = reinterpret_cast<const X *>(vscalars);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if (zTadShapeInfo == nullptr) { if (zTadShapeInfo == nullptr) {
@ -92,60 +92,50 @@ namespace functions {
template<typename X, typename Y> template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(int opNum, void ScalarBoolTransform<X,Y>::transform(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, const void *scalars,
Nd4jLong *zShapeInfo, int *dimension, int dimensionLength,
void *scalars, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
int *dimension, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
int dimensionLength, const uint64_t start, const uint64_t stop) {
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS); DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS);
} }
template<typename X, typename Y> template<typename X, typename Y>
void ScalarBoolTransform<X, Y>::transform(const int opNum, void ScalarBoolTransform<X, Y>::transform(const int opNum,
void *x, const void *x, Nd4jLong xEws,
Nd4jLong xEws, void *z, Nd4jLong zEws,
void *z, const void *scalar,
Nd4jLong zEws, void *extraParams,
void *scalar, const uint64_t n,
void *extraParams, const uint64_t start, const uint64_t stop) {
const uint64_t n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS); DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS);
} }
template<typename X, typename Y> template<typename X, typename Y>
void ScalarBoolTransform<X, Y>::transform(const int opNum, void ScalarBoolTransform<X, Y>::transform(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *z, const void *scalar,
Nd4jLong *zShapeInfo, void *extraParams,
void *scalar, const uint64_t start, const uint64_t stop) {
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS); DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS);
} }
template<typename X, typename Z> template<typename X, typename Z>
template<typename OpType> template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx, void ScalarBoolTransform<X, Z>::transform(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vz, const void *vscalar,
Nd4jLong *zShapeInfo, void *vextraParams,
void *vscalar, const uint64_t start, const uint64_t stop) {
void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0]; auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -185,18 +175,16 @@ namespace functions {
template<typename X, typename Z> template<typename X, typename Z>
template<typename OpType> template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx, void ScalarBoolTransform<X, Z>::transform(const void *vx, Nd4jLong xEws,
Nd4jLong xEws, void *vz, Nd4jLong zEws,
void *vz, const void *vscalar,
Nd4jLong zEws, void *vextraParams,
void *vscalar, const uint64_t len,
void *vextraParams, const uint64_t start, const uint64_t stop) {
const uint64_t len,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0]; auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if (xEws == 1 && zEws == 1) { if (xEws == 1 && zEws == 1) {

View File

@ -34,18 +34,18 @@ namespace functions {
template<typename X> template<typename X>
template<typename OpType> template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx, Nd4jLong *xShapeInfo, void ScalarIntTransform<X>::transform(const void *vx, const Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vscalars, const void *vscalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) { const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto scalars = reinterpret_cast<X *>(vscalars); auto scalars = reinterpret_cast<const X *>(vscalars);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if (zTadShapeInfo == nullptr) { if (zTadShapeInfo == nullptr) {
@ -92,19 +92,14 @@ namespace functions {
template<typename X> template<typename X>
void ScalarIntTransform<X>::transform(int opNum, void ScalarIntTransform<X>::transform(int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, const void *scalars,
Nd4jLong *zShapeInfo, int *dimension, int dimensionLength,
void *scalars, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets,
int *dimension, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets,
int dimensionLength, const uint64_t start, const uint64_t stop) {
Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS); DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS);
} }
@ -112,42 +107,35 @@ namespace functions {
template<typename X> template<typename X>
void ScalarIntTransform<X>::transform(const int opNum, void ScalarIntTransform<X>::transform(const int opNum,
void *x, const void *x, Nd4jLong xEws,
Nd4jLong xEws, void *z, Nd4jLong zEws,
void *z, const void *scalar,
Nd4jLong zEws, void *extraParams,
void *scalar, const uint64_t n,
void *extraParams, const uint64_t start, const uint64_t stop) {
const uint64_t n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS); DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS);
} }
template<typename X> template<typename X>
void ScalarIntTransform<X>::transform(const int opNum, void ScalarIntTransform<X>::transform(const int opNum,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo,
void *z, const void *scalar,
Nd4jLong *zShapeInfo, void *extraParams,
void *scalar, const uint64_t start, const uint64_t stop) {
void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS); DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS);
} }
template<typename X> template<typename X>
template<typename OpType> template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx, void ScalarIntTransform<X>::transform(const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vz, const void *vscalar, void *vextraParams,
Nd4jLong *zShapeInfo, const uint64_t start, const uint64_t stop) {
void *vscalar,
void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0]; auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -187,18 +175,15 @@ namespace functions {
template<typename X> template<typename X>
template<typename OpType> template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx, void ScalarIntTransform<X>::transform(const void *vx, Nd4jLong xEws,
Nd4jLong xEws, void *vz, Nd4jLong zEws,
void *vz, const void *vscalar,
Nd4jLong zEws, void *vextraParams,
void *vscalar, const uint64_t len, const uint64_t start, const uint64_t stop) {
void *vextraParams,
const uint64_t len,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0]; auto scalar = reinterpret_cast<const X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if (scalar < (sizeof(X) * 8)) { if (scalar < (sizeof(X) * 8)) {

View File

@ -34,54 +34,46 @@ namespace functions {
template <typename X, typename Y> template <typename X, typename Y>
Y SummaryStatsReduce<X,Y>::execScalar(const int opNum, Y SummaryStatsReduce<X,Y>::execScalar(const int opNum,
const bool biasCorrected, const bool biasCorrected,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams) {
void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), SUMMARY_STATS_OPS); RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), SUMMARY_STATS_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::execScalar(const int opNum, void SummaryStatsReduce<X,Y>::execScalar(const int opNum,
const bool biasCorrected, const bool biasCorrected,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo) {
void *z,
Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS); DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::exec(const int opNum, void SummaryStatsReduce<X,Y>::exec(const int opNum,
const bool biasCorrected, const bool biasCorrected,
void *x, const void *x, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *extraParams, void *z, const Nd4jLong *zShapeInfo,
void *z, int *dimension, int dimensionLength) {
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType > template <typename OpType >
void SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected, void SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected,
void *vx, const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo) {
void *vz,
Nd4jLong *zShapeInfo) {
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams); z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType > template <typename OpType >
Z SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected, void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { Z SummaryStatsReduce<X,Z>::execScalar(const bool biasCorrected, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
SummaryStatsData<X> startingIndex; SummaryStatsData<X> startingIndex;
@ -105,15 +97,12 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType > template <typename OpType >
void SummaryStatsReduce<X,Z>::exec(const bool biasCorrected, void SummaryStatsReduce<X,Z>::exec(const bool biasCorrected,
void *vx, const void *vx, const Nd4jLong *xShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vextraParams, void *vz, const Nd4jLong *zShapeInfo,
void *vz, int *dimension, int dimensionLength) {
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
auto resultLength = shape::length(zShapeInfo); auto resultLength = shape::length(zShapeInfo);

View File

@ -30,25 +30,23 @@ namespace functions {
namespace transform { namespace transform {
template <typename X, typename Y> template <typename X, typename Y>
void TransformAny<X, Y>::exec( void TransformAny<X, Y>::exec(int opNum,
int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *z, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS);
} }
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo, void _CUDA_H TransformAny<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz,Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) { void *vextraParams,
uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -30,27 +30,22 @@ namespace functions {
namespace transform { namespace transform {
template <typename X, typename Y> template <typename X, typename Y>
void TransformBool<X, Y>::exec( void TransformBool<X, Y>::exec(int opNum,
int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *z, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H TransformBool<X, Z>::exec( void _CUDA_H TransformBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vx, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vz, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -29,27 +29,22 @@ using namespace simdOps;
namespace functions { namespace functions {
namespace transform { namespace transform {
template <typename X, typename Y> template <typename X, typename Y>
void TransformFloat<X, Y>::exec( void TransformFloat<X, Y>::exec(int opNum,
int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *z, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *extraParams,
uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H TransformFloat<X, Z>::exec( void _CUDA_H TransformFloat<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vx, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vz, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);

View File

@ -30,24 +30,22 @@ namespace functions {
namespace transform { namespace transform {
template <typename X> template <typename X>
void TransformSame<X>::exec( void TransformSame<X>::exec(int opNum,
int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, void *z, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *extraParams,
void *z, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *extraParams, uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS); DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS);
} }
template <typename X> template <typename X>
template<typename OpType> template<typename OpType>
void _CUDA_H TransformSame<X>::exec(void *vx, Nd4jLong *xShapeInfo, void _CUDA_H TransformSame<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams, void *vextraParams,
uint64_t threadId, uint64_t numThreads) { uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -30,26 +30,23 @@ namespace functions {
namespace transform { namespace transform {
template <typename X> template <typename X>
void TransformStrict<X>::exec( void TransformStrict<X>::exec(int opNum,
int opNum, const void *x, const Nd4jLong *xShapeInfo,
void *x, void *z,
Nd4jLong *xShapeInfo, const Nd4jLong *zShapeInfo,
void *z, void *extraParams,
Nd4jLong *zShapeInfo, uint64_t threadId, uint64_t numThreads) {
void *extraParams, uint64_t threadId, uint64_t numThreads) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS); DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS);
} }
template <typename X> template <typename X>
template<typename OpType> template<typename OpType>
void _CUDA_H TransformStrict<X>::exec( void _CUDA_H TransformStrict<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
void *vx, void *vz, const Nd4jLong *zShapeInfo,
Nd4jLong *xShapeInfo, void *vextraParams,
void *vz, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *zShapeInfo,
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<const X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);

View File

@ -34,22 +34,22 @@ using namespace simdOps;
template<typename X, typename Y, typename Z, typename OpClass> template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastSimple( static __global__ void broadcastSimple(
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *y, void const* y,
Nd4jLong *yShapeInfo, Nd4jLong const* yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
template<typename X, typename Y, typename Z, typename OpClass> template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastSimple(const void *x, const Nd4jLong *xShapeInfo, static __global__ void broadcastSimple(const void const* x, const Nd4jLong const* xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void const* y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong *zShapeInfo ) { void *z, const Nd4jLong const* zShapeInfo ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
} }
@ -57,14 +57,14 @@ static __global__ void broadcastSimple(const void *x, const Nd4jLong *xShapeInfo
template<typename X, typename Y, typename Z, typename OpClass> template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastInverseSimple( static __global__ void broadcastInverseSimple(
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *y, void const* y,
Nd4jLong *yShapeInfo, Nd4jLong const* yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::Broadcast<X,Y,Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
@ -73,17 +73,17 @@ static __global__ void broadcastInverseSimple(
namespace functions { namespace functions {
namespace broadcast { namespace broadcast {
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) { static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo); return shape::getIndexOffset(index, shapeInfo);
} }
static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) { static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) {
return shape::length(shapeInfo); return shape::length(shapeInfo);
} }
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template <typename OpClass> template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
} }
@ -94,14 +94,14 @@ namespace functions {
} }
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS)) DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
} }
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) { __host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong const* zShapeInfo) {
DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS)) DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
@ -109,12 +109,12 @@ namespace functions {
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template <typename OpClass> template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void Broadcast<X,Y,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastInverseSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastInverseSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
} }
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void Broadcast<X,Y,Z>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TTT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS)) DISPATCH_BY_OPNUM_TTT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
@ -123,19 +123,19 @@ namespace functions {
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template <typename OpType> template <typename OpType>
__device__ void Broadcast<X,Y,Z>::transformInverseCuda( __device__ void Broadcast<X,Y,Z>::transformInverseCuda(
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void* vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) { if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo; tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets; tadOffsetsZ = tadOffsets;
} }
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<Y*>(vy); auto y = reinterpret_cast<Y const*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after
@ -189,19 +189,19 @@ namespace functions {
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template <typename OpType> template <typename OpType>
__device__ void Broadcast<X,Y,Z>::transformCuda( __device__ void Broadcast<X,Y,Z>::transformCuda(
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) { if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo; tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets; tadOffsetsZ = tadOffsets;
} }
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<Y*>(vy); auto y = reinterpret_cast<Y const*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after

View File

@ -34,24 +34,24 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass> template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolSimple( static __global__ void broadcastBoolSimple(
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *y, void const* y,
Nd4jLong *yShapeInfo, Nd4jLong const* yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass> template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolSimple(const void *x, const Nd4jLong *xShapeInfo, static __global__ void broadcastBoolSimple(const void const* x, const Nd4jLong const* xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void const* y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong *zShapeInfo, void *z, const Nd4jLong const* zShapeInfo,
void *extraParams) { void *extraParams) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
@ -59,15 +59,15 @@ static __global__ void broadcastBoolSimple(const void *x, const Nd4jLong *xShape
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass> template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolInverseSimple( static __global__ void broadcastBoolInverseSimple(
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *y, void const* y,
Nd4jLong *yShapeInfo, Nd4jLong const* yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
void *extraParams, void *extraParams,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
@ -78,7 +78,7 @@ namespace broadcast {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpClass> template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
} }
@ -98,7 +98,7 @@ __host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStr
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
@ -119,14 +119,14 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpClass> template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
@ -136,20 +136,20 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
__device__ void BroadcastBool<X,Z>::transformInverseCuda( __device__ void BroadcastBool<X,Z>::transformInverseCuda(
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams, void *vextraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) { if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo; tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets; tadOffsetsZ = tadOffsets;
} }
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams); auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -198,20 +198,20 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
__device__ void BroadcastBool<X,Z>::transformCuda( __device__ void BroadcastBool<X,Z>::transformCuda(
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams, void *vextraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) { if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo; tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets; tadOffsetsZ = tadOffsets;
} }
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams); auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -235,7 +235,7 @@ __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *s
__syncthreads(); __syncthreads();
__shared__ Z *rZ; __shared__ Z *rZ;
__shared__ X *rX; __shared__ X const* rX;
for (int r = blockIdx.x; r < numTads; r += gridDim.x) { for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

View File

@ -34,23 +34,23 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass> template<typename X, typename OpClass>
static __global__ void broadcastIntSimple( static __global__ void broadcastIntSimple(
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *y, void const* y,
Nd4jLong *yShapeInfo, Nd4jLong const* yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass> template<typename X, typename OpClass>
static __global__ void broadcastIntSimple(const void *x, const Nd4jLong *xShapeInfo, static __global__ void broadcastIntSimple(const void *x, const Nd4jLong const* xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void *y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) { void *z, const Nd4jLong const* zShapeInfo) {
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
} }
@ -58,14 +58,14 @@ static __global__ void broadcastIntSimple(const void *x, const Nd4jLong *xShapeI
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass> template<typename X, typename OpClass>
static __global__ void broadcastBoolInverseSimple( static __global__ void broadcastBoolInverseSimple(
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *y, void const* y,
Nd4jLong *yShapeInfo, Nd4jLong const* yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
functions::broadcast::BroadcastInt<X>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::BroadcastInt<X>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
@ -75,7 +75,7 @@ namespace broadcast {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
template <typename OpClass> template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
} }
@ -92,16 +92,16 @@ __host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo, const void *x, const Nd4jLong const* xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo, const void *y, const Nd4jLong const* yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) { void *z, const Nd4jLong const* zShapeInfo) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS)) DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS))
} }
@ -109,13 +109,13 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
template <typename OpClass> template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastInt<X>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
broadcastBoolInverseSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastBoolInverseSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
__host__ void BroadcastInt<X>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastInt<X>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
} }
@ -123,19 +123,19 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
template<typename X> template<typename X>
template <typename OpType> template <typename OpType>
__device__ void BroadcastInt<X>::transformInverseCuda( __device__ void BroadcastInt<X>::transformInverseCuda(
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) { if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo; tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets; tadOffsetsZ = tadOffsets;
} }
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after
@ -183,19 +183,19 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
template<typename X> template<typename X>
template <typename OpType> template <typename OpType>
__device__ void BroadcastInt<X>::transformCuda( __device__ void BroadcastInt<X>::transformCuda(
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
if (tadOnlyShapeInfoZ == nullptr) { if (tadOnlyShapeInfoZ == nullptr) {
tadOnlyShapeInfoZ = tadOnlyShapeInfo; tadOnlyShapeInfoZ = tadOnlyShapeInfo;
tadOffsetsZ = tadOffsets; tadOffsetsZ = tadOffsets;
} }
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
//decompose in to several sub tads after //decompose in to several sub tads after
@ -218,7 +218,7 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
__syncthreads(); __syncthreads();
__shared__ X *rZ; __shared__ X *rZ;
__shared__ X *rX; __shared__ X const* rX;
for (int r = blockIdx.x; r < numTads; r += gridDim.x) { for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
@ -250,9 +250,9 @@ __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stre
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
template <typename OpType> template <typename OpType>
__device__ void BroadcastInt<X>::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, __device__ void BroadcastInt<X>::transformCuda(const void *vx, const Nd4jLong const* xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong const* yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) { void *vz, const Nd4jLong const* zShapeInfo) {
const X* x = reinterpret_cast<const X*>(vx); const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy); const X* y = reinterpret_cast<const X*>(vy);

View File

@ -31,14 +31,14 @@ using namespace simdOps;
template <typename X, typename Z> template <typename X, typename Z>
static __global__ void simpleIndexReduceGeneric(const int op, static __global__ void simpleIndexReduceGeneric(const int op,
void *dx, void const* dx,
Nd4jLong *xShapeInfo, int xRank, Nd4jLong const* xShapeInfo, int xRank,
void *extraParams, void *extraParams,
void *result, void *result,
Nd4jLong *zShapeInfo, int zRank, Nd4jLong const* zShapeInfo, int zRank,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) {
functions::indexreduce::IndexReduce<X, Z>::transform(op,dx,xShapeInfo,extraParams,result,zShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); functions::indexreduce::IndexReduce<X, Z>::transform(op,dx,xShapeInfo,extraParams,result,zShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets);
} }
@ -49,15 +49,15 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
_CUDA_H void IndexReduce<X,Z>::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, _CUDA_H void IndexReduce<X,Z>::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream,
const int opNum, const int opNum,
void *dx, Nd4jLong *xShapeInfo, void const* dx, Nd4jLong const* xShapeInfo,
int xRank, int xRank,
void *extraParams, void *extraParams,
void *result, Nd4jLong *zShapeInfo, void *result, Nd4jLong const* zShapeInfo,
int zRank, int zRank,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
int postProcessOrNot, int postProcessOrNot,
int *allocationBuffer, void *reductionBuffer, int *allocationBuffer, void *reductionBuffer,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) {
simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(opNum, simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(opNum,
dx, xShapeInfo, xRank, dx, xShapeInfo, xRank,
@ -70,7 +70,7 @@ namespace functions {
} }
template <typename X, typename Z> template <typename X, typename Z>
_CUDA_H void IndexReduce<X, Z>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { _CUDA_H void IndexReduce<X, Z>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void const* dx, Nd4jLong const* xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong const* zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) {
simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>( simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(
opNum, opNum,
dx, dx,
@ -154,35 +154,35 @@ namespace functions {
template <typename X, typename Y> template <typename X, typename Y>
__device__ void IndexReduce<X, Y>::transform( __device__ void IndexReduce<X, Y>::transform(
const int opNum, const int opNum,
void *x, void const* x,
Nd4jLong *xShapeInfo, Nd4jLong const* xShapeInfo,
void *extraParams, void *extraParams,
void *result, void *result,
Nd4jLong *zShapeInfo, Nd4jLong const* zShapeInfo,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
int postProcessOrNot, int postProcessOrNot,
int *allocationBuffer, int *allocationBuffer,
void *reductionBuffer, void *reductionBuffer,
Nd4jLong *tadShapeInfo, Nd4jLong const* tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong const* tadOffset) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
__device__ void IndexReduce<X, Z>::transform(void *vdx, Nd4jLong *xShapeInfo, __device__ void IndexReduce<X, Z>::transform(void const* vdx, Nd4jLong const* xShapeInfo,
void *vextraParams, void *vextraParams,
void *vz, Nd4jLong *zShapeInfo, void* vz, Nd4jLong const* zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
int postProcessOrNot, int postProcessOrNot,
int *allocationBuffer, void *vreductionBuffer, int *allocationBuffer, void *vreductionBuffer,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets){ Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets){
/**int /**int
* Gpu information for the problem * Gpu information for the problem
*/ */
auto dx = reinterpret_cast<X*>(vdx); auto dx = reinterpret_cast<X const*>(vdx);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = static_cast<X*>(vextraParams); auto extraParams = static_cast<X*>(vextraParams);
auto reductionBuffer = static_cast<X*>(vreductionBuffer); auto reductionBuffer = static_cast<X*>(vreductionBuffer);

View File

@ -28,13 +28,13 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType> template <typename X, typename Y, typename Z, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo, __global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams) { void *vextraParams) {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<Y*>(vy); auto y = reinterpret_cast<Y const*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<Z*>(vextraParams); auto extraParams = reinterpret_cast<Z*>(vextraParams);
@ -91,9 +91,9 @@ namespace pairwise_transforms {
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void __host__ PairWiseTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream, void __host__ PairWiseTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams){ void *vextraParams){
pairwiseSimpleShaped<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); pairwiseSimpleShaped<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
@ -101,7 +101,7 @@ void __host__ PairWiseTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cud
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
void __host__ PairWiseTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) { void __host__ PairWiseTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void* vextraParams) {
DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS); DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS);
} }

View File

@ -28,13 +28,13 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType> template <typename X, typename Z, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo, __global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams) { void *vextraParams) {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams); auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -92,9 +92,9 @@ namespace pairwise_transforms {
template<typename X, typename Z> template<typename X, typename Z>
template<typename OpType> template<typename OpType>
void _CUDA_H PairWiseBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream, void _CUDA_H PairWiseBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams){ void *vextraParams){
pairwiseSimpleShaped<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); pairwiseSimpleShaped<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
@ -103,7 +103,7 @@ void _CUDA_H PairWiseBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cu
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) { void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) {
auto xType = sd::DataTypeUtils::fromT<X>(); auto xType = sd::DataTypeUtils::fromT<X>();
auto yType = sd::DataTypeUtils::fromT<Y>(); auto yType = sd::DataTypeUtils::fromT<Y>();

View File

@ -28,13 +28,13 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType> template <typename X, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo, __global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams) { void *vextraParams) {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X const*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X const*>(vy);
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams); auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -92,9 +92,9 @@ namespace pairwise_transforms {
template<typename X> template<typename X>
template<typename OpType> template<typename OpType>
void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream, void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo, void const* vx, Nd4jLong const* xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void const* vy, Nd4jLong const* yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong const* zShapeInfo,
void *vextraParams){ void *vextraParams){
pairwiseSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); pairwiseSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
@ -103,7 +103,7 @@ void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaS
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template<typename X> template<typename X>
void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) { void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) {
auto xType = sd::DataTypeUtils::fromT<X>(); auto xType = sd::DataTypeUtils::fromT<X>();
DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS); DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);

Some files were not shown because too many files have changed in this diff Show More