/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ /* * broadcasting.h * * Created on: Dec 28, 2015 * Author: agibsonccc */ #ifndef BROADCASTING_BOOL_H_ #define BROADCASTING_BOOL_H_ #include #include #include #include #include #include #include #ifdef __CUDACC__ #include #include #endif #ifdef __JNI__ #include #endif #include #include "legacy_ops.h" namespace functions { namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along a bigger one.
*/
template <typename X, typename Z>
class BroadcastBool {
public:
#ifdef __CUDACC__
    // ------------------------------------------------------------------
    // CUDA build: device-side transforms and host-side launchers.
    // Only declarations live here; the definitions are elsewhere.
    // ------------------------------------------------------------------

    /**
     * Device-side broadcast of y against x along the axes listed in
     * `dimension`, writing into `result`.
     *
     * @param x / xShapeInfo              input data and its shape information
     * @param y / yShapeInfo              the smaller operand and its shape information
     * @param result / resultShapeInfo    output data and its shape information
     * @param extraParams                 op-specific extra arguments (opaque here)
     * @param dimension                   the dimension(s) to broadcast along
     * @param dimensionLength             number of entries in `dimension`
     * @param tadOnlyShapeInfo / tadOffsets    precomputed sub-array shape info and
     *        offsets for x (presumably TADs - confirm against the definition)
     * @param tadOnlyShapeInfoZ / tadOffsetsZ  the same for the result
     */
    template <typename OpType>
    static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
                                         const void *y, const Nd4jLong *yShapeInfo,
                                         void *result, const Nd4jLong *resultShapeInfo,
                                         void *extraParams,
                                         int *dimension, int dimensionLength,
                                         const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
                                         const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);

    /**
     * Device-side broadcast without an explicit dimension list: x, y and z
     * are related only through their shape infos.
     * NOTE(review): presumably shape-based (NumPy-style) broadcasting - confirm.
     */
    template <typename OpType>
    static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
                                         const void *y, const Nd4jLong *yShapeInfo,
                                         void *z, const Nd4jLong *zShapeInfo,
                                         void *extraParams);

    /**
     * Host-side launcher for the dimension-based transformCuda, using
     * `launchDims` on `*stream` (presumably; definition not visible here).
     */
    template <typename OpType>
    static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
                                               const void *x, const Nd4jLong *xShapeInfo,
                                               const void *y, const Nd4jLong *yShapeInfo,
                                               void *result, const Nd4jLong *resultShapeInfo,
                                               void *extraParams,
                                               int *dimension, int dimensionLength,
                                               const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
                                               const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);

    /** Host-side launcher for the shape-only transformCuda overload (presumably). */
    template <typename OpType>
    static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
                                               const void *x, const Nd4jLong *xShapeInfo,
                                               const void *y, const Nd4jLong *yShapeInfo,
                                               void *z, const Nd4jLong *zShapeInfo,
                                               void *extraParams);

    /**
     * Runtime-dispatched host entry point for the dimension-based broadcast.
     * NOTE(review): presumably maps `opNum` to an op (see legacy_ops.h) and
     * forwards to intermediateBroadcast - confirm against the definition.
     */
    static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum,
                                       const void *x, const Nd4jLong *xShapeInfo,
                                       const void *y, const Nd4jLong *yShapeInfo,
                                       void *result, const Nd4jLong *resultShapeInfo,
                                       void *extraParams,
                                       int *dimension, int dimensionLength,
                                       const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
                                       const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);

    /** Runtime-dispatched host entry point for the shape-only broadcast. */
    static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
                                       const void *x, const Nd4jLong *xShapeInfo,
                                       const void *y, const Nd4jLong *yShapeInfo,
                                       void *z, const Nd4jLong *zShapeInfo,
                                       void *extraParams);

    /**
     * Device-side inverse broadcast; same parameters as the dimension-based
     * transformCuda.
     * NOTE(review): presumably the roles of x and y are swapped - confirm.
     */
    template <typename OpType>
    static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo,
                                                const void *y, const Nd4jLong *yShapeInfo,
                                                void *result, const Nd4jLong *resultShapeInfo,
                                                void *extraParams,
                                                int *dimension, int dimensionLength,
                                                const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
                                                const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);

    /** Host-side launcher for transformInverseCuda (presumably). */
    template <typename OpType>
    static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream,
                                                      const void *x, const Nd4jLong *xShapeInfo,
                                                      const void *y, const Nd4jLong *yShapeInfo,
                                                      void *result, const Nd4jLong *resultShapeInfo,
                                                      void *extraParams,
                                                      int *dimension, int dimensionLength,
                                                      const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
                                                      const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);

    /** Runtime-dispatched host entry point for the inverse broadcast. */
    static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum,
                                              const void *x, const Nd4jLong *xShapeInfo,
                                              const void *y, const Nd4jLong *yShapeInfo,
                                              void *result, const Nd4jLong *resultShapeInfo,
                                              void *extraParams,
                                              int *dimension, int dimensionLength,
                                              const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
                                              const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ);
#else
    // ------------------------------------------------------------------
    // CPU build.
    // ------------------------------------------------------------------

    /**
     * Runtime-dispatched CPU entry point for the dimension-based broadcast.
     * NOTE(review): presumably maps `opNum` to an op and forwards to the
     * templated exec below. `start`/`stop` presumably bound the range of work
     * (e.g. sub-arrays) handled by this call - confirm against the definition.
     */
    static void exec(int opNum,
                     const void *x, const Nd4jLong *xShapeInfo,
                     const void *y, const Nd4jLong *yShapeInfo,
                     void *result, const Nd4jLong *resultShapeInfo,
                     void *extraParams,
                     int *dimension, int dimensionLength,
                     const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
                     const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
                     uint64_t start, uint64_t stop);

    /** Runtime-dispatched CPU entry point for the shape-only broadcast. */
    static void exec(int opNum,
                     const void *x, const Nd4jLong *xShapeInfo,
                     const void *y, const Nd4jLong *yShapeInfo,
                     void *z, const Nd4jLong *zShapeInfo,
                     void *extraParams);

    /** Runtime-dispatched CPU entry point for the inverse broadcast. */
    static void execInverse(int opNum,
                            const void *x, const Nd4jLong *xShapeInfo,
                            const void *y, const Nd4jLong *yShapeInfo,
                            void *result, const Nd4jLong *resultShapeInfo,
                            void *extraParams,
                            int *dimension, int dimensionLength,
                            const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
                            const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
                            uint64_t start, uint64_t stop);

    /**
     * CPU execution
     * @param x the input
     * @param xShapeInfo the x shape information
     * @param y the y data
     * @param yShapeInfo the y shape information
     * @param result the result
     * @param resultShapeInfo the result shape information
     * @param extraParams op-specific extra arguments (opaque here)
     * @param dimension the dimension to broadcast along
     * @param dimensionLength the length of the dimension buffer
     * @param tadShapeInfo / tadOffset    precomputed sub-array shape info and
     *        offsets for x (presumably TADs - confirm)
     * @param tadShapeInfoZ / tadOffsetZ  the same for the result
     * @param start / stop  presumably the range of work handled by this call - confirm
     */
    template <typename OpType>
    static void exec(const void *x, const Nd4jLong *xShapeInfo,
                     const void *y, const Nd4jLong *yShapeInfo,
                     void *result, const Nd4jLong *resultShapeInfo,
                     void *extraParams,
                     int *dimension, int dimensionLength,
                     const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
                     const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
                     uint64_t start, uint64_t stop);

    /** Compile-time-dispatched CPU shape-only broadcast. */
    template <typename OpType>
    static void exec(const void *x, const Nd4jLong *xShapeInfo,
                     const void *y, const Nd4jLong *yShapeInfo,
                     void *z, const Nd4jLong *zShapeInfo,
                     void *extraParams);

    /** Compile-time-dispatched CPU inverse broadcast; parameters as in exec. */
    template <typename OpType>
    static void execInverse(const void *x, const Nd4jLong *xShapeInfo,
                            const void *y, const Nd4jLong *yShapeInfo,
                            void *result, const Nd4jLong *resultShapeInfo,
                            void *extraParams,
                            int *dimension, int dimensionLength,
                            const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
                            const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ,
                            uint64_t start, uint64_t stop);
#endif
};

}  // namespace broadcast
}  // namespace functions

#endif /* BROADCASTING_BOOL_H_ */