/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <system/Environment.h>
#include <loops/transform_same.h>
#include <system/op_boilerplate.h>
#include <loops/legacy_ops.h>
#include <helpers/DebugHelper.h>
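
// This translation unit provides the CUDA backend for the legacy "transform"
// loops: element-wise ops applied to a single buffer. Two kernel flavors are
// defined below: a "strided" path for flat buffers addressed by an
// element-wise stride, and a "shaped" path that walks arbitrary shape info.
// Per-op __global__ kernels are stamped out by the DISPATCH_KERNEL_SIMPLE
// macros for float, double and float16.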

template <typename T>
__device__ void transformGeneric(
        int opNum,
        Nd4jLong n,
        T *dy,
        Nd4jLong incy,
        T *params,
        T *result,
        Nd4jLong resultStride, int *allocationPointer, T *reductionPointer) {

    functions::transform::Transform<T>::transformCuda(
        opNum,
        n,
        dy,
        incy,
        params,
        result,
        resultStride,
        allocationPointer,
        reductionPointer,
        nullptr);
}
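
// Note the split between the two trampolines: transformGeneric receives the
// op number at runtime and dispatches inside the kernel, while
// transformSimpleGeneric is templated on OpClass, so the op is resolved at
// compile time and each generated kernel is specialized for a single op.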

template <typename T, typename OpClass>
__device__ void transformSimpleGeneric(
        Nd4jLong n,
        T *dy,
        Nd4jLong incy,
        T *params,
        T *result,
        Nd4jLong resultStride, int *allocationPointer, T *reductionPointer) {

    functions::transform::Transform<T>::template transformCuda<OpClass>(
        n,
        dy,
        incy,
        params,
        result,
        resultStride,
        allocationPointer,
        reductionPointer,
        nullptr);
}

template <typename T>
__device__ void transformGeneric(
        int opNum,
        T *dy,
        Nd4jLong *xShapeInfo, int xRank,
        T *params,
        T *result, Nd4jLong *resultShapeInfo, int zRank, int *allocationPointer, T *reductionPointer) {

    functions::transform::Transform<T>::transformCuda(
        opNum,
        dy,
        xShapeInfo,
        params,
        result,
        resultShapeInfo,
        allocationPointer,
        reductionPointer,
        nullptr, nullptr, nullptr);
}
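
// The shaped trampoline below carves a UnifiedSharedMemory manager out of the
// kernel's dynamic shared memory: thread 0 placement-constructs it over the
// shmem buffer and initializes it with the sizes of the Transform functor and
// shape::TAD, then all threads synchronize before the manager is used.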
template <typename T, typename OpClass>
__device__ void transformSimpleGeneric(
        T *dy,
        Nd4jLong *xShapeInfo, int xRank,
        T *params,
        T *result, Nd4jLong *resultShapeInfo, int zRank, int *allocationPointer, T *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    __shared__ UnifiedSharedMemory *manager;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        manager = new(shmem) UnifiedSharedMemory((int *) shmem);
        manager->init(sizeof(UnifiedSharedMemory), 0, sizeof(functions::transform::Transform<T>), sizeof(shape::TAD), xRank);
    }
    __syncthreads();

    functions::transform::Transform<T>::template transformCuda<OpClass>(
        dy,
        xShapeInfo,
        params,
        result,
        resultShapeInfo,
        allocationPointer,
        reductionPointer,
        manager, tadShapeInfo, tadOffsets);
}

// transform strided
DISPATCH_KERNEL_SIMPLE(transformStrided_, transformSimpleGeneric, float, INPUT(Nd4jLong n, float *x, Nd4jLong xStride, float *extraParams, float *z, Nd4jLong zStride, int *allocationPointer, float *reductionPointer), PARAMS(n, x, xStride, extraParams, z, zStride, allocationPointer, reductionPointer), OPS_A(TRANSFORM_OPS))
DISPATCH_KERNEL_SIMPLE(transformStrided_, transformSimpleGeneric, double, INPUT(Nd4jLong n, double *x, Nd4jLong xStride, double *extraParams, double *z, Nd4jLong zStride, int *allocationPointer, double *reductionPointer), PARAMS(n, x, xStride, extraParams, z, zStride, allocationPointer, reductionPointer), OPS_A(TRANSFORM_OPS))
DISPATCH_KERNEL_SIMPLE(transformStrided_, transformSimpleGeneric, float16, INPUT(Nd4jLong n, float16 *x, Nd4jLong xStride, float16 *extraParams, float16 *z, Nd4jLong zStride, int *allocationPointer, float16 *reductionPointer), PARAMS(n, x, xStride, extraParams, z, zStride, allocationPointer, reductionPointer), OPS_A(TRANSFORM_OPS))

// transform shaped
DISPATCH_KERNEL_SIMPLE(transformShaped_, transformSimpleGeneric, float, INPUT(float *x, Nd4jLong *xShape, int xRank, float *extraParams, float *z, Nd4jLong *zShape, int zRank, int *allocationPointer, float *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets), PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
DISPATCH_KERNEL_SIMPLE(transformShaped_, transformSimpleGeneric, double, INPUT(double *x, Nd4jLong *xShape, int xRank, double *extraParams, double *z, Nd4jLong *zShape, int zRank, int *allocationPointer, double *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets), PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
DISPATCH_KERNEL_SIMPLE(transformShaped_, transformSimpleGeneric, float16, INPUT(float16 *x, Nd4jLong *xShape, int xRank, float16 *extraParams, float16 *z, Nd4jLong *zShape, int zRank, int *allocationPointer, float16 *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets), PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
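
// Rough sketch of what each DISPATCH_KERNEL_SIMPLE line above is expected to
// stamp out, one kernel per op in TRANSFORM_OPS (illustrative only, not the
// exact macro expansion; the real naming scheme lives in op_boilerplate.h):
//
//   __global__ void transformStrided_SomeOp_float(
//           Nd4jLong n, float *x, Nd4jLong xStride, float *extraParams,
//           float *z, Nd4jLong zStride, int *allocationPointer,
//           float *reductionPointer) {
//       transformSimpleGeneric<float, simdOps::SomeOp<float>>(
//           n, x, xStride, extraParams, z, zStride,
//           allocationPointer, reductionPointer);
//   }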

namespace functions {
namespace transform {

template <>
_CUDA_H void Transform<float>::executeTransformStrided(dim3 launchDims, cudaStream_t *stream, int opNum, Nd4jLong n, float *x, Nd4jLong xStride, float *extraParams, float *z, Nd4jLong zStride, int *allocationPointer, float *reductionPointer) {
    DISPATCH_SIMPLE(transformStrided, float, PARAMS(n, x, xStride, extraParams, z, zStride, allocationPointer, reductionPointer), OPS_A(TRANSFORM_OPS))

    DEBUG_KERNEL(stream, opNum);
}

template <>
_CUDA_H void Transform<double>::executeTransformStrided(dim3 launchDims, cudaStream_t *stream, int opNum, Nd4jLong n, double *x, Nd4jLong xStride, double *extraParams, double *z, Nd4jLong zStride, int *allocationPointer, double *reductionPointer) {
    DISPATCH_SIMPLE(transformStrided, double, PARAMS(n, x, xStride, extraParams, z, zStride, allocationPointer, reductionPointer), OPS_A(TRANSFORM_OPS))

    DEBUG_KERNEL(stream, opNum);
}

template <>
_CUDA_H void Transform<float16>::executeTransformStrided(dim3 launchDims, cudaStream_t *stream, int opNum, Nd4jLong n, float16 *x, Nd4jLong xStride, float16 *extraParams, float16 *z, Nd4jLong zStride, int *allocationPointer, float16 *reductionPointer) {
    DISPATCH_SIMPLE(transformStrided, float16, PARAMS(n, x, xStride, extraParams, z, zStride, allocationPointer, reductionPointer), OPS_A(TRANSFORM_OPS))

    DEBUG_KERNEL(stream, opNum);
}
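
// Typical host-side launch (illustrative; the launch dims, stream and device
// pointers here are placeholders chosen by the caller, not fixed by this API):
//
//   dim3 launchDims(512, 512, 2048);
//   functions::transform::Transform<float>::executeTransformStrided(
//       launchDims, stream, opNum, n,
//       dX, 1, dExtraParams, dZ, 1,
//       dAllocationPointer, dReductionPointer);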

template <>
_CUDA_H void Transform<float>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, float *x, Nd4jLong *xShape, int xRank, float *extraParams, float *z, Nd4jLong *zShape, int zRank, int *allocationPointer, float *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    DISPATCH_SIMPLE(transformShaped, float, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))

    DEBUG_KERNEL(stream, opNum);
}

template <>
_CUDA_H void Transform<float16>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, float16 *x, Nd4jLong *xShape, int xRank, float16 *extraParams, float16 *z, Nd4jLong *zShape, int zRank, int *allocationPointer, float16 *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    DISPATCH_SIMPLE(transformShaped, float16, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))

    DEBUG_KERNEL(stream, opNum);
}

template <>
_CUDA_H void Transform<double>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, double *x, Nd4jLong *xShape, int xRank, double *extraParams, double *z, Nd4jLong *zShape, int zRank, int *allocationPointer, double *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    DISPATCH_SIMPLE(transformShaped, double, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))

    DEBUG_KERNEL(stream, opNum);
}
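
// The shaped transformCuda below picks between two paths: if both buffers
// expose a positive element-wise stride and share the same ordering, it
// forwards to the flat strided kernel; otherwise it falls back to a
// rank-general loop that converts each linear index to coordinates
// (shape::ind2sub) and resolves per-buffer offsets via shape::getOffset.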
template <typename T>
template <typename OpType>
__device__ void Transform<T>::transformCuda(
        T *dy,
        Nd4jLong *shapeInfo,
        T *params,
        T *result,
        Nd4jLong *resultShapeInfo,
        int *allocationPointer, T *reductionPointer, UnifiedSharedMemory *manager, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    if (OpType::requiresSpecial) {
        OpType::execSpecialCuda(dy, shapeInfo, result, resultShapeInfo, params, allocationPointer, reductionPointer, manager, tadShapeInfo, tadOffsets);
        return;
    } else {
        auto xOrder = shape::order(shapeInfo);
        auto resultOrder = shape::order(resultShapeInfo);
        auto xRank = shape::rank(shapeInfo);

        auto xElementWiseStride = shape::elementWiseStride(shapeInfo);
        auto resultElementWiseStride = shape::elementWiseStride(resultShapeInfo);
        auto tid = blockIdx.x * blockDim.x + threadIdx.x;

        __shared__ Nd4jLong length;
        if (threadIdx.x == 0)
            length = shape::length(shapeInfo);
        __syncthreads();

        if (xElementWiseStride >= 1 && resultElementWiseStride >= 1 && xOrder == resultOrder) {
            transformCuda<OpType>(
                length,
                dy,
                xElementWiseStride,
                params,
                result,
                resultElementWiseStride, allocationPointer, reductionPointer, manager);
        } else {
            Nd4jLong xCoord[MAX_RANK];

            for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
                shape::ind2sub(xRank, shape::shapeOf(shapeInfo), i, length, xCoord);

                auto xOffset2 = shape::getOffset(shapeInfo, xCoord);
                auto resultOffset2 = shape::getOffset(resultShapeInfo, xCoord);

                result[resultOffset2] = OpType::op(dy[xOffset2], params);
            }
        }
    }
}
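
// The strided transformCuda is a plain grid-stride loop; the unit-stride
// branch exists purely to skip the two index multiplications in the hot path.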
template <typename T>
template <typename OpType>
__device__ void Transform<T>::transformCuda(
        Nd4jLong n,
        T *dy,
        Nd4jLong incy,
        T *params,
        T *result,
        Nd4jLong resultStride,
        int *allocationPointer, T *reductionPointer, UnifiedSharedMemory *manager) {

    int totalThreads = gridDim.x * blockDim.x;
    Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x;

    if (incy == 1 && resultStride == 1) {
        /* both increments are unit: contiguous fast path. */
        for (; i < n; i += totalThreads) {
            result[i] = OpType::op(dy[i], params);
        }
    } else {
        for (; i < n; i += totalThreads) {
            result[i * resultStride] = OpType::op(dy[i * incy], params);
        }
    }
}

template <typename T>
__device__ void Transform<T>::transformCuda(
        const int opNum,
        T *dy,
        Nd4jLong *shapeInfo,
        T *params,
        T *result,
        Nd4jLong *resultShapeInfo,
        int *allocationPointer,
        T *reductionPointer,
        UnifiedSharedMemory *manager, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    DISPATCH_BY_OPNUM(transformCuda, PARAMS(dy, shapeInfo, params, result, resultShapeInfo, allocationPointer, reductionPointer, manager, tadShapeInfo, tadOffsets), TRANSFORM_OPS);
}

template <typename T>
__device__ void Transform<T>::transformCuda(
        const int opNum,
        Nd4jLong n,
        T *dy,
        Nd4jLong incy,
        T *params,
        T *result,
        Nd4jLong resultStride,
        int *allocationPointer,
        T *reductionPointer,
        UnifiedSharedMemory *manager) {
    DISPATCH_BY_OPNUM(transformCuda, PARAMS(n, dy, incy, params, result, resultStride, allocationPointer, reductionPointer, manager), TRANSFORM_OPS);
}
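
// Illustrative only: DISPATCH_BY_OPNUM is expected to expand to a switch over
// opNum that forwards to the matching transformCuda<OpType> instantiation,
// roughly:
//
//   switch (opNum) {
//       case 0: transformCuda<simdOps::SomeOp<T>>(...); break;
//       ...
//   }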

//template class ND4J_EXPORT Transform<float>;
//template class ND4J_EXPORT Transform<float16>;
//template class ND4J_EXPORT Transform<double>;

BUILD_CALL_1(template __device__ void Transform<float>::transformCuda, float, (float*, Nd4jLong*, float*, float*, Nd4jLong*, int*, float*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS)
BUILD_CALL_1(template __device__ void Transform<float16>::transformCuda, float16, (float16*, Nd4jLong*, float16*, float16*, Nd4jLong*, int*, float16*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS)
BUILD_CALL_1(template __device__ void Transform<double>::transformCuda, double, (double*, Nd4jLong*, double*, double*, Nd4jLong*, int*, double*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS)
}
}