444 lines
21 KiB
Plaintext
444 lines
21 KiB
Plaintext
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
|
|
|
|
|
|
#include <op_boilerplate.h>
|
|
#include <pointercast.h>
|
|
#include <helpers/TAD.h>
|
|
#include <types/float16.h>
|
|
#include <loops/grid_shaped.legacy>
|
|
#include <helpers/DebugHelper.h>
|
|
|
|
|
|
#include <ops/meta_ops.h>
|
|
#include <loops/legacy_ops.h>
|
|
|
|
|
|
#define GRID_WIDTH 19 // number of pointers within single grid row
|
|
|
|
/**
 * Builds the per-block argument table for a chained pair of ops (A, then B).
 *
 * Thread 0 publishes a two-slot pointer table in shared memory: slot 0 holds
 * the arguments of op A, slot 1 those of op B. A slot points at the scalar
 * when the matching op type is 0 and at the caller-supplied extra arguments
 * otherwise.
 *
 * The scalars arrive as double but are read back through T pointers, so they
 * are converted to T and stored in shared memory before being published.
 *
 * NOTE: the meta-op dispatch at the end of this function is commented out, so
 * the function currently writes nothing to vdz.
 *
 * @tparam T element type of the buffers
 * @param opTypeA type of the first op (0 selects the scalar as its argument)
 * @param opNumA  number of the first op (unused here)
 * @param opTypeB type of the second op (0 selects the scalar as its argument)
 * @param opNumB  number of the second op (unused here)
 * @param N       unused here
 * @param vdx, xShapeInfo first operand buffer and its shape info
 * @param vdy, yShapeInfo second operand buffer and its shape info
 * @param vdz, zShapeInfo output buffer and its shape info
 * @param vextraA, vextraB extra arguments of op A / op B, interpreted as T*
 * @param scalarA, scalarB scalar arguments of op A / op B
 */
template <typename T>
__device__ inline static void metaPredicateShapeGeneric(const int opTypeA, const int opNumA, const int opTypeB, const int opNumB,
                                                        Nd4jLong N, void *vdx, Nd4jLong *xShapeInfo, void *vdy, Nd4jLong *yShapeInfo, void *vdz, Nd4jLong *zShapeInfo,
                                                        void *vextraA, void *vextraB, double scalarA, double scalarB) {

    // dx / dy / dz are only referenced by the disabled dispatch below
    auto dx = static_cast<T*>(vdx);
    auto dy = static_cast<T*>(vdy);
    auto dz = static_cast<T*>(vdz);
    auto extraA = static_cast<T*>(vextraA);
    auto extraB = static_cast<T*>(vextraB);

    __shared__ Nd4jPointer params[2];
    __shared__ T *paramsPtr;
    // Raw bytes instead of T[2]: stays valid for element types that have user-defined constructors.
    __shared__ __align__(8) unsigned char scalarBytes[2 * sizeof(T)];

    if (threadIdx.x == 0) {
        // Publish the scalars as T, in memory every thread of the block can read.
        T *scalars = reinterpret_cast<T *>(scalarBytes);
        scalars[0] = static_cast<T>(scalarA);
        scalars[1] = static_cast<T>(scalarB);

        params[0] = reinterpret_cast<Nd4jPointer>(opTypeA == 0 ? scalars : extraA);
        params[1] = reinterpret_cast<Nd4jPointer>(opTypeB == 0 ? scalars + 1 : extraB);

        paramsPtr = reinterpret_cast<T *>(params);
    }
    __syncthreads();

    if (opTypeA == 2) {
        if (opTypeB == 0) {
            // DISPATCH_METAOP(functions::pairwise_transforms::PairWiseTransform<T>::template transformCuda, PARAMS(dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, paramsPtr, nullptr, nullptr, nullptr), InvertedMetaOp, OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS));
            // functions::pairwise_transforms::PairWiseTransform<T>::template transformCuda<simdOps::InvertedMetaOp<T, simdOps::Copy<T>, simdOps::Multiply<T>>>(dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, paramsPtr, nullptr, nullptr, nullptr);
        }
    }
}
|
|
|
|
/**
 * Runs the compile-time selected meta-op OpClass over shaped buffers.
 *
 * Thread 0 publishes a two-slot pointer table in shared memory (slot 0: the
 * arguments of op A, slot 1: the arguments of op B). A slot points at the
 * scalar when the matching op type is 0 and at the caller-supplied extra
 * arguments otherwise. The table is then handed to
 * GRIDShaped<T>::transformCuda<OpClass> as its extra-parameters pointer.
 *
 * The scalars arrive as double but are read back through T pointers, so they
 * are converted to T and stored in shared memory before being published.
 *
 * @tparam T       element type of the buffers
 * @tparam OpClass op type forwarded to GRIDShaped<T>::transformCuda
 * @param opTypeA, opTypeB op types; 0 selects the scalar as that op's argument
 * @param N       unused here
 * @param vdx, xShapeInfo first operand buffer and its shape info
 * @param vdy, yShapeInfo second operand buffer and its shape info
 * @param vdz, zShapeInfo output buffer and its shape info
 * @param vextraA, vextraB extra arguments of op A / op B, interpreted as T*
 * @param scalarA, scalarB scalar arguments of op A / op B
 */
template<typename T, typename OpClass>
__device__ static inline void invertedMetaPairwiseShapedGeneric(const int opTypeA, const int opTypeB, Nd4jLong N, void *vdx,
                                                                Nd4jLong *xShapeInfo, void *vdy, Nd4jLong *yShapeInfo, void *vdz, Nd4jLong *zShapeInfo,
                                                                void *vextraA, void *vextraB, double scalarA, double scalarB) {

    auto dx = static_cast<T*>(vdx);
    auto dy = static_cast<T*>(vdy);
    auto dz = static_cast<T*>(vdz);
    auto extraA = static_cast<T*>(vextraA);
    auto extraB = static_cast<T*>(vextraB);

    __shared__ Nd4jPointer params[2];
    __shared__ T *paramsPtr;
    // Raw bytes instead of T[2]: stays valid for element types that have user-defined constructors.
    __shared__ __align__(8) unsigned char scalarBytes[2 * sizeof(T)];

    if (threadIdx.x == 0) {
        // Publish the scalars as T, in memory every thread of the block can read.
        T *scalars = reinterpret_cast<T *>(scalarBytes);
        scalars[0] = static_cast<T>(scalarA);
        scalars[1] = static_cast<T>(scalarB);

        params[0] = reinterpret_cast<Nd4jPointer>(opTypeA == 0 ? scalars : extraA);
        params[1] = reinterpret_cast<Nd4jPointer>(opTypeB == 0 ? scalars + 1 : extraB);

        paramsPtr = reinterpret_cast<T *>(params);
    }
    // the table must be complete before any thread starts executing ops
    __syncthreads();

    functions::grid::GRIDShaped<T>::template transformCuda<OpClass>(dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, paramsPtr, nullptr, nullptr, nullptr);
}
|
|
|
|
/**
 * Variant of invertedMetaPairwiseShapedGeneric that also forwards the op
 * type/number pairs to GRIDShaped<T>::transformCuda<OpClass>.
 *
 * Thread 0 publishes a two-slot pointer table in shared memory (slot 0: the
 * arguments of op A, slot 1: the arguments of op B). A slot points at the
 * scalar when the matching op type is 0 and at the caller-supplied extra
 * arguments otherwise.
 *
 * The scalars arrive as double but are read back through T pointers, so they
 * are converted to T and stored in shared memory before being published.
 *
 * NOTE(review): no templated transformCuda overload taking op type/number
 * arguments is visible in this file - confirm it is declared elsewhere.
 *
 * @tparam T       element type of the buffers
 * @tparam OpClass op type forwarded to GRIDShaped<T>::transformCuda
 * @param opTypeA, opNumA type and number of the first op
 * @param opTypeB, opNumB type and number of the second op
 * @param N       unused here
 * @param vdx, xShapeInfo first operand buffer and its shape info
 * @param vdy, yShapeInfo second operand buffer and its shape info
 * @param vdz, zShapeInfo output buffer and its shape info
 * @param vextraA, vextraB extra arguments of op A / op B, interpreted as T*
 * @param scalarA, scalarB scalar arguments of op A / op B
 */
template<typename T, typename OpClass>
__device__ static inline void invertedMetaPairwiseShapedGeneric(const int opTypeA, const int opNumA, const int opTypeB, const int opNumB,
                                                                Nd4jLong N, void *vdx, Nd4jLong *xShapeInfo, void *vdy, Nd4jLong *yShapeInfo, void *vdz,
                                                                Nd4jLong *zShapeInfo, void *vextraA, void *vextraB, double scalarA, double scalarB) {

    auto dx = static_cast<T*>(vdx);
    auto dy = static_cast<T*>(vdy);
    auto dz = static_cast<T*>(vdz);
    auto extraA = static_cast<T*>(vextraA);
    auto extraB = static_cast<T*>(vextraB);

    __shared__ Nd4jPointer params[2];
    __shared__ T *paramsPtr;
    // Raw bytes instead of T[2]: stays valid for element types that have user-defined constructors.
    __shared__ __align__(8) unsigned char scalarBytes[2 * sizeof(T)];

    if (threadIdx.x == 0) {
        // Publish the scalars as T, in memory every thread of the block can read.
        T *scalars = reinterpret_cast<T *>(scalarBytes);
        scalars[0] = static_cast<T>(scalarA);
        scalars[1] = static_cast<T>(scalarB);

        params[0] = reinterpret_cast<Nd4jPointer>(opTypeA == 0 ? scalars : extraA);
        params[1] = reinterpret_cast<Nd4jPointer>(opTypeB == 0 ? scalars + 1 : extraB);

        paramsPtr = reinterpret_cast<T *>(params);
    }
    // the table must be complete before any thread starts executing ops
    __syncthreads();

    functions::grid::GRIDShaped<T>::template transformCuda<OpClass>(opTypeA, opNumA, opTypeB, opNumB, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, paramsPtr, nullptr, nullptr, nullptr);
}
|
|
|
|
/**
 * Runs a chained pair of ops, selected at run time by type/number, over
 * shaped buffers.
 *
 * Thread 0 publishes a two-slot pointer table in shared memory (slot 0: the
 * arguments of op A, slot 1: the arguments of op B). A slot points at the
 * scalar when the matching op type is 0 and at the caller-supplied extra
 * arguments otherwise. The table is then handed to the numeric
 * GRIDShaped<T>::transformCuda overload as its extra-parameters pointer.
 *
 * The scalars arrive as double but are read back through T pointers, so they
 * are converted to T and stored in shared memory before being published.
 *
 * @tparam T element type of the buffers
 * @param opTypeA, opNumA type and number of the first op
 * @param opTypeB, opNumB type and number of the second op
 * @param N       unused here
 * @param vdx, xShapeInfo first operand buffer and its shape info
 * @param vdy, yShapeInfo second operand buffer and its shape info
 * @param vdz, zShapeInfo output buffer and its shape info
 * @param vextraA, vextraB extra arguments of op A / op B, interpreted as T*
 * @param scalarA, scalarB scalar arguments of op A / op B
 */
template<typename T>
__device__ static inline void invertedMetaPairwiseShapedNumericGeneric(const int opTypeA, const int opNumA, const int opTypeB, const int opNumB,
                                                                       Nd4jLong N, void *vdx, Nd4jLong *xShapeInfo, void *vdy, Nd4jLong *yShapeInfo, void *vdz,
                                                                       Nd4jLong *zShapeInfo, void *vextraA, void *vextraB, double scalarA, double scalarB) {

    auto dx = static_cast<T*>(vdx);
    auto dy = static_cast<T*>(vdy);
    auto dz = static_cast<T*>(vdz);
    auto extraA = static_cast<T*>(vextraA);
    auto extraB = static_cast<T*>(vextraB);

    __shared__ Nd4jPointer params[2];
    __shared__ T *paramsPtr;
    // Raw bytes instead of T[2]: stays valid for element types that have user-defined constructors.
    __shared__ __align__(8) unsigned char scalarBytes[2 * sizeof(T)];

    if (threadIdx.x == 0) {
        // Publish the scalars as T, in memory every thread of the block can read.
        T *scalars = reinterpret_cast<T *>(scalarBytes);
        scalars[0] = static_cast<T>(scalarA);
        scalars[1] = static_cast<T>(scalarB);

        params[0] = reinterpret_cast<Nd4jPointer>(opTypeA == 0 ? scalars : extraA);
        params[1] = reinterpret_cast<Nd4jPointer>(opTypeB == 0 ? scalars + 1 : extraB);

        paramsPtr = reinterpret_cast<T *>(params);
    }
    // the table must be complete before any thread starts executing ops
    __syncthreads();

    functions::grid::GRIDShaped<T>::transformCuda(opTypeA, opNumA, opTypeB, opNumB, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, paramsPtr, nullptr, nullptr, nullptr);
}
|
|
|
|
|
|
/**
 * Kernel entry point for the run-time selected meta-op on float buffers.
 * Forwards all arguments unchanged to invertedMetaPairwiseShapedNumericGeneric<float>.
 */
extern "C" __global__ void invertedMetaPairwiseShapedNumericFloat(
        const int opTypeA, const int opNumA,
        const int opTypeB, const int opNumB,
        Nd4jLong N,
        void *dx, Nd4jLong *xShapeInfo,
        void *dy, Nd4jLong *yShapeInfo,
        void *dz, Nd4jLong *zShapeInfo,
        void *extraA, void *extraB,
        double scalarA, double scalarB) {
    invertedMetaPairwiseShapedNumericGeneric<float>(
            opTypeA, opNumA, opTypeB, opNumB, N,
            dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo,
            extraA, extraB, scalarA, scalarB);
}
|
|
|
|
/**
 * Kernel entry point for the run-time selected meta-op on double buffers.
 * Forwards all arguments unchanged to invertedMetaPairwiseShapedNumericGeneric<double>.
 */
extern "C" __global__ void invertedMetaPairwiseShapedNumericDouble(
        const int opTypeA, const int opNumA,
        const int opTypeB, const int opNumB,
        Nd4jLong N,
        void *dx, Nd4jLong *xShapeInfo,
        void *dy, Nd4jLong *yShapeInfo,
        void *dz, Nd4jLong *zShapeInfo,
        void *extraA, void *extraB,
        double scalarA, double scalarB) {
    invertedMetaPairwiseShapedNumericGeneric<double>(
            opTypeA, opNumA, opTypeB, opNumB, N,
            dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo,
            extraA, extraB, scalarA, scalarB);
}
|
|
|
|
/**
 * Kernel entry point for the run-time selected meta-op on float16 buffers.
 * Forwards all arguments unchanged to invertedMetaPairwiseShapedNumericGeneric<float16>.
 */
extern "C" __global__ void invertedMetaPairwiseShapedNumericHalf(
        const int opTypeA, const int opNumA,
        const int opTypeB, const int opNumB,
        Nd4jLong N,
        void *dx, Nd4jLong *xShapeInfo,
        void *dy, Nd4jLong *yShapeInfo,
        void *dz, Nd4jLong *zShapeInfo,
        void *extraA, void *extraB,
        double scalarA, double scalarB) {
    invertedMetaPairwiseShapedNumericGeneric<float16>(
            opTypeA, opNumA, opTypeB, opNumB, N,
            dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo,
            extraA, extraB, scalarA, scalarB);
}
|
|
|
|
|
|
|
|
#ifndef __CLION_IDE__
|
|
// kernels set for pairwise + scalar based on shape
|
|
//DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, float, metaOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, Nd4jLong N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS))
|
|
//DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, double, metaOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, Nd4jLong N, double *dx, int *xShapeInfo, double *dy, int *yShapeInfo, double *dz, int *zShapeInfo, double *extraA, double *extraB, double scalarA, double scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS))
|
|
//DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, float16, metaOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, Nd4jLong N, float16 *dx, int *xShapeInfo, float16 *dy, int *yShapeInfo, float16 *dz, int *zShapeInfo, float16 *extraA, float16 *extraB, float16 scalarA, float16 scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS))
|
|
#endif
|
|
|
|
namespace functions {
|
|
namespace grid {
|
|
// Forward declarations of the per-element executors defined further down in this file.

// Chains op A (two operands) and op B (one operand) on a single element pair.
template <typename T>
__device__ __noinline__ T invertedOpExecutorA(const int opTypeA, const int opNumA,
                                              const int opTypeB, const int opNumB,
                                              T x, T y, T *extras);

// Executes an op that takes two operands (x, y) plus extras.
template <typename T>
__device__ __noinline__ T execute_2OE(const int opType, const int opNum, T x, T y, T *extras);

// Executes an op that takes one operand (x) plus extras.
template <typename T>
__device__ __noinline__ T execute_1OE(const int opType, const int opNum, T x, T *extras);
|
|
|
|
|
|
/**
 * Non-inlined forwarder to shape::ind2subC (overload that takes the length).
 * The result is written into coords by the callee.
 */
__device__ __noinline__ void _ind2subC(int rank, Nd4jLong *shape, Nd4jLong idx, Nd4jLong length, Nd4jLong *coords) {
    shape::ind2subC(rank,
                    shape,
                    idx,
                    length,
                    coords);
}
|
|
|
|
/**
 * Non-inlined forwarder to shape::ind2subC (overload without the length).
 * The result is written into coords by the callee.
 */
__device__ __noinline__ void _ind2subC(int rank, Nd4jLong *shape, Nd4jLong idx, Nd4jLong *coords) {
    shape::ind2subC(rank,
                    shape,
                    idx,
                    coords);
}
|
|
|
|
/**
 * Non-inlined forwarder to shape::getOffset.
 *
 * @return whatever shape::getOffset returns for the given arguments
 */
__device__ __noinline__ Nd4jLong _getOffset(Nd4jLong offset, Nd4jLong *shape, Nd4jLong *stride, Nd4jLong *coords, int rank) {
    const Nd4jLong resolved = shape::getOffset(offset, shape, stride, coords, rank);
    return resolved;
}
|
|
|
|
/**
 * Non-inlined forwarder to shape::shapeOf.
 *
 * @return the pointer produced by shape::shapeOf for the given buffer
 */
__device__ __noinline__ Nd4jLong* _shapeOf(Nd4jLong *shape) {
    Nd4jLong *dims = shape::shapeOf(shape);
    return dims;
}
|
|
|
|
/**
 * Non-inlined forwarder to shape::stride.
 *
 * @return the pointer produced by shape::stride for the given buffer
 */
__device__ __noinline__ Nd4jLong* _stride(Nd4jLong *shape) {
    Nd4jLong *strides = shape::stride(shape);
    return strides;
}
|
|
|
|
/**
 * Non-inlined forwarder to shape::rank.
 *
 * @return the value produced by shape::rank for the given buffer
 */
__device__ __noinline__ int _rank(Nd4jLong* shape) {
    const int dims = shape::rank(shape);
    return dims;
}
|
|
|
|
/**
 * This method is able to execute various ops that take 2 operands (x, y) + extras.
 *
 * Only opType 2 is handled; it expands EXECUTE_NOE over PAIRWISE_TRANSFORM_OPS.
 * Any other opType is reported through PRINT_FIRST.
 *
 * @tparam T element type
 * @param opType op family selector (only 2 is handled)
 * @param opNum  op number within the family; it is not named explicitly here
 *               (presumably consumed by the EXECUTE_NOE expansion - verify
 *               against op_boilerplate.h)
 * @param x      first operand
 * @param y      second operand
 * @param extras extra arguments passed to the op
 * @return the value left in z by the macro expansion, or a value-initialised T
 *         when nothing assigned it
 */
template <typename T>
__device__ __noinline__ T execute_2OE(const int opType, const int opNum, T x, T y, T *extras) {
    // Value-initialised so an unknown opType never returns an indeterminate value.
    T z = T();

    switch (opType) {
        case 2: {
            EXECUTE_NOE((x, y, extras), OPS_A(PAIRWISE_TRANSFORM_OPS));
        }
            break;
        default: {
            PRINT_FIRST("Unknown opType provided: [%i]\n", opType);
        }
            break;
    }

    return z;
}
|
|
|
|
|
|
/**
 * This method is able to execute various ops that take 1 operand (x) + extras.
 *
 * Only opType 0 is handled; it expands EXECUTE_NOE over SCALAR_OPS.
 * Any other opType is reported through PRINT_FIRST.
 *
 * @tparam T element type
 * @param opType op family selector (only 0 is handled)
 * @param opNum  op number within the family; it is not named explicitly here
 *               (presumably consumed by the EXECUTE_NOE expansion - verify
 *               against op_boilerplate.h)
 * @param x      operand
 * @param extras extra arguments passed to the op
 * @return the value left in z by the macro expansion, or a value-initialised T
 *         when nothing assigned it
 */
template <typename T>
__device__ __noinline__ T execute_1OE(const int opType, const int opNum, T x, T *extras) {
    // Value-initialised so an unknown opType never returns an indeterminate value.
    T z = T();

    switch (opType) {
        case 0: {
            EXECUTE_NOE((x, extras), OPS_A(SCALAR_OPS));
        }
            break;
        default: {
            PRINT_FIRST("Unknown opType provided: [%i]\n", opType);
        }
            break;
    }

    return z;
}
|
|
|
|
/**
 * Applies op A to (x, y) and then op B to that result.
 *
 * This is InvertedMetaOp reorganised so that both ops are selected at run
 * time by type/number. extras is read as a two-slot pointer table: slot 0 is
 * passed to op A as its arguments, slot 1 to op B.
 *
 * @return the result of op B
 */
template <typename T>
__device__ __noinline__ T invertedOpExecutorA(const int opTypeA, const int opNumA, const int opTypeB, const int opNumB, T x, T y, T *extras) {
    Nd4jPointer *table = reinterpret_cast<Nd4jPointer *>(extras);
    T *argsA = reinterpret_cast<T *>(table[0]);
    T *argsB = reinterpret_cast<T *>(table[1]);

    // first op: two operands
    T afterA = functions::grid::execute_2OE<T>(opTypeA, opNumA, x, y, argsA);

    // second op: one operand, fed with the first op's result
    T afterB = functions::grid::execute_1OE<T>(opTypeB, opNumB, afterA, argsB);

    return afterB;
}
|
|
|
|
/**
 * Element-wise execution of a run-time selected op pair over shaped buffers.
 *
 * Thread 0 of each block caches rank, shape and stride pointers in shared
 * memory. Every thread then walks the linear indices [0, n) in a grid-stride
 * loop, where n is the length of x. Each linear index is converted to
 * per-buffer coordinates and then to a buffer offset before
 * invertedOpExecutorA is applied.
 *
 * When dx and result are the same pointer the x offset is reused for the
 * output and resultShapeBuffer is not consulted for offsets.
 *
 * The thread index and stride are computed in 64-bit arithmetic so large
 * launch configurations cannot wrap a 32-bit product.
 *
 * @param opTypeA, opNumA type and number of the first op
 * @param opTypeB, opNumB type and number of the second op
 * @param dx, xShapeBuffer first operand and its shape info
 * @param y, yShapeBuffer second operand and its shape info
 * @param result, resultShapeBuffer output and its shape info
 * @param extraParams passed through to invertedOpExecutorA
 * @param allocationPointer, manager, tadOnlyShapeInfo not used by this implementation
 */
template<typename T>
__device__ void GRIDShaped<T>::transformCuda(int opTypeA, int opNumA, int opTypeB, int opNumB, T *dx, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *result, Nd4jLong *resultShapeBuffer, T *extraParams, int *allocationPointer, UnifiedSharedMemory *manager, Nd4jLong *tadOnlyShapeInfo) {
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong totalThreads = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;
    const bool inPlace = dx == result;

    __shared__ int xRank;
    __shared__ int yRank;
    __shared__ int resultRank;
    __shared__ Nd4jLong n;

    __shared__ Nd4jLong *xShape;
    __shared__ Nd4jLong *yShape;
    __shared__ Nd4jLong *zShape;

    __shared__ Nd4jLong *xStride;
    __shared__ Nd4jLong *yStride;
    __shared__ Nd4jLong *zStride;

    if (threadIdx.x == 0) {
        xRank = _rank(xShapeBuffer);
        yRank = _rank(yShapeBuffer);
        resultRank = _rank(resultShapeBuffer);
        n = shape::length(xShapeBuffer);

        xShape = _shapeOf(xShapeBuffer);
        yShape = _shapeOf(yShapeBuffer);

        // output shape/stride are only needed when writing to a separate buffer
        if (!inPlace) {
            zShape = _shapeOf(resultShapeBuffer);
            zStride = _stride(resultShapeBuffer);
        }

        xStride = _stride(xShapeBuffer);
        yStride = _stride(yShapeBuffer);
    }
    __syncthreads();

    Nd4jLong xCoord[MAX_RANK];
    Nd4jLong yCoord[MAX_RANK];
    Nd4jLong resultCoord[MAX_RANK];

    for (Nd4jLong i = tid; i < n; i += totalThreads) {
        _ind2subC(xRank, xShape, i, n, xCoord);
        _ind2subC(yRank, yShape, i, n, yCoord);

        auto xOffset = _getOffset(0, xShape, xStride, xCoord, xRank);
        auto yOffset = _getOffset(0, yShape, yStride, yCoord, yRank);

        // in-place execution writes back to where x was read from
        auto resultOffset = xOffset;
        if (!inPlace) {
            _ind2subC(resultRank, zShape, i, n, resultCoord);
            resultOffset = _getOffset(0, zShape, zStride, resultCoord, resultRank);
        }

        result[resultOffset] = functions::grid::invertedOpExecutorA<T>(opTypeA, opNumA, opTypeB, opNumB, dx[xOffset], y[yOffset], extraParams);
    }
}
|
|
|
|
/**
 * Element-wise execution of the compile-time selected op OpType over shaped
 * buffers.
 *
 * Thread 0 of each block caches rank, shape and stride pointers in shared
 * memory. Every thread then walks the linear indices [0, n) in a grid-stride
 * loop, where n is the length of x. Each linear index is converted to
 * per-buffer coordinates and then to a buffer offset before OpType::op is
 * applied.
 *
 * When dx and result are the same pointer the x offset is reused for the
 * output and resultShapeBuffer is not consulted for offsets.
 *
 * The thread index and stride are computed in 64-bit arithmetic so large
 * launch configurations cannot wrap a 32-bit product.
 *
 * @tparam OpType type providing a static op(x, y, extraParams)
 * @param dx, xShapeBuffer first operand and its shape info
 * @param y, yShapeBuffer second operand and its shape info
 * @param result, resultShapeBuffer output and its shape info
 * @param extraParams passed through to OpType::op
 * @param allocationPointer, manager, tadOnlyShapeInfo not used by this implementation
 */
template<typename T>
template<typename OpType>
__device__ void GRIDShaped<T>::transformCuda(T *dx, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *result, Nd4jLong *resultShapeBuffer, T *extraParams, int *allocationPointer, UnifiedSharedMemory *manager, Nd4jLong *tadOnlyShapeInfo) {
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong totalThreads = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;
    const bool inPlace = dx == result;

    __shared__ int xRank;
    __shared__ int yRank;
    __shared__ int resultRank;
    __shared__ Nd4jLong n;

    __shared__ Nd4jLong *xShape;
    __shared__ Nd4jLong *yShape;
    __shared__ Nd4jLong *zShape;

    __shared__ Nd4jLong *xStride;
    __shared__ Nd4jLong *yStride;
    __shared__ Nd4jLong *zStride;

    if (threadIdx.x == 0) {
        xRank = _rank(xShapeBuffer);
        yRank = _rank(yShapeBuffer);
        resultRank = _rank(resultShapeBuffer);
        n = shape::length(xShapeBuffer);

        xShape = _shapeOf(xShapeBuffer);
        yShape = _shapeOf(yShapeBuffer);

        // output shape/stride are only needed when writing to a separate buffer
        if (!inPlace) {
            zShape = _shapeOf(resultShapeBuffer);
            zStride = _stride(resultShapeBuffer);
        }

        xStride = _stride(xShapeBuffer);
        yStride = _stride(yShapeBuffer);
    }
    __syncthreads();

    Nd4jLong xCoord[MAX_RANK];
    Nd4jLong yCoord[MAX_RANK];
    Nd4jLong resultCoord[MAX_RANK];

    for (Nd4jLong i = tid; i < n; i += totalThreads) {
        _ind2subC(xRank, xShape, i, n, xCoord);
        _ind2subC(yRank, yShape, i, n, yCoord);

        auto xOffset = _getOffset(0, xShape, xStride, xCoord, xRank);
        auto yOffset = _getOffset(0, yShape, yStride, yCoord, yRank);

        // in-place execution writes back to where x was read from
        auto resultOffset = xOffset;
        if (!inPlace) {
            _ind2subC(resultRank, zShape, i, n, resultCoord);
            resultOffset = _getOffset(0, zShape, zStride, resultCoord, resultRank);
        }

        result[resultOffset] = OpType::op(dx[xOffset], y[yOffset], extraParams);
    }
}
|
|
|
|
|
|
/**
 * Host-side launcher of the float meta-op kernel on the given stream.
 * The launch configuration is fixed: 128 blocks of 1024 threads with 2048
 * bytes of dynamic shared memory. DEBUG_KERNEL is invoked afterwards with
 * the stream and opNumA.
 */
template <>
void GRIDShaped<float>::execMetaPredicateShaped(cudaStream_t * stream, Nd4jPointer *extras, const int opTypeA, const int opNumA, const int opTypeB, const int opNumB, Nd4jLong N, float *dx, Nd4jLong *xShapeInfo, float *dy, Nd4jLong *yShapeInfo, float *dz, Nd4jLong *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB) {
    const int blocks = 128;
    const int threadsPerBlock = 1024;
    const int sharedMemBytes = 2048;

    invertedMetaPairwiseShapedNumericFloat<<<blocks, threadsPerBlock, sharedMemBytes, *stream>>>(
            opTypeA, opNumA, opTypeB, opNumB, N,
            dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo,
            extraA, extraB, scalarA, scalarB);

    DEBUG_KERNEL(stream, opNumA);
}
|
|
|
|
/**
 * Host-side launcher of the float16 meta-op kernel on the given stream.
 * The launch configuration is fixed: 128 blocks of 1024 threads with 2048
 * bytes of dynamic shared memory.
 *
 * DEBUG_KERNEL is invoked with opNumA, matching the float and double
 * specialisations (this one previously passed opNumB).
 */
template <>
void GRIDShaped<float16>::execMetaPredicateShaped(cudaStream_t * stream, Nd4jPointer *extras, const int opTypeA, const int opNumA, const int opTypeB, const int opNumB, Nd4jLong N, float16 *dx, Nd4jLong *xShapeInfo, float16 *dy, Nd4jLong *yShapeInfo, float16 *dz, Nd4jLong *zShapeInfo, float16 *extraA, float16 *extraB, float16 scalarA, float16 scalarB) {
    invertedMetaPairwiseShapedNumericHalf<<<128, 1024, 2048, *stream>>>(opTypeA, opNumA, opTypeB, opNumB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB);

    DEBUG_KERNEL(stream, opNumA);
}
|
|
|
|
/**
 * Host-side launcher of the double meta-op kernel on the given stream.
 * The launch configuration is fixed: 128 blocks of 1024 threads with 2048
 * bytes of dynamic shared memory. DEBUG_KERNEL is invoked afterwards with
 * the stream and opNumA.
 */
template <>
void GRIDShaped<double>::execMetaPredicateShaped(cudaStream_t * stream, Nd4jPointer *extras, const int opTypeA, const int opNumA, const int opTypeB, const int opNumB, Nd4jLong N, double *dx, Nd4jLong *xShapeInfo, double *dy, Nd4jLong *yShapeInfo, double *dz, Nd4jLong *zShapeInfo, double *extraA, double *extraB, double scalarA, double scalarB) {
    const int blocks = 128;
    const int threadsPerBlock = 1024;
    const int sharedMemBytes = 2048;

    invertedMetaPairwiseShapedNumericDouble<<<blocks, threadsPerBlock, sharedMemBytes, *stream>>>(
            opTypeA, opNumA, opTypeB, opNumB, N,
            dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo,
            extraA, extraB, scalarA, scalarB);

    DEBUG_KERNEL(stream, opNumA);
}
|
|
}
|
|
} |