################################################################################ # Copyright (c) 2015-2018 Skymind, Inc. # # This program and the accompanying materials are made available under the # terms of the Apache License, Version 2.0 which is available at # https://www.apache.org/licenses/LICENSE-2.0. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. # # SPDX-License-Identifier: Apache-2.0 ################################################################################ #ifndef NDARRAY_MACRO #define NDARRAY_MACRO #include //NDArray *other, T *extraParams BUILD_CALL_1(template void NDArray::template applyPairwiseTransform, float, (NDArray* other, float* extraParams), PAIRWISE_TRANSFORM_OPS) BUILD_CALL_1(template void NDArray::applyPairwiseTransform, float16, (NDArray* other, float16* extraParams), PAIRWISE_TRANSFORM_OPS) BUILD_CALL_1(template void NDArray::applyPairwiseTransform, double, (NDArray* other, double* extraParams), PAIRWISE_TRANSFORM_OPS) // NDArray *other, NDArray *target, T *extraParams BUILD_CALL_1(template void nd4j::NDArray::applyPairwiseTransform, float, (NDArray* other, NDArray* target, float* extraParams), PAIRWISE_TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyPairwiseTransform, float16, (NDArray* other, NDArray* target, float16* extraParams), PAIRWISE_TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyPairwiseTransform, double, (NDArray* other, NDArray* target, double* extraParams), PAIRWISE_TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyScalar, float16, (NDArray& scalar, NDArray* target, float16 *extraParams) const, SCALAR_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyScalar, float16, (float16 scalar, NDArray* target, float16 *extraParams) const, SCALAR_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyScalar, float, (NDArray& scalar, NDArray* target, float *extraParams) const, SCALAR_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyScalar, float, (float scalar, NDArray* target, float *extraParams) const, SCALAR_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyScalar, double, (NDArray& scalar, NDArray* target, double *extraParams) const, SCALAR_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyScalar, double, (double scalar, NDArray* target, double *extraParams) const, SCALAR_OPS) BUILD_CALL_1(template float16 nd4j::NDArray::reduceNumber, float16, (float16 *extraParams) const, REDUCE_OPS) BUILD_CALL_1(template float nd4j::NDArray::reduceNumber, float, (float *extraParams) const, REDUCE_OPS) BUILD_CALL_1(template double nd4j::NDArray::reduceNumber, double, (double *extraParams) const, REDUCE_OPS) BUILD_CALL_1(template Nd4jLong nd4j::NDArray::indexReduceNumber, float16, (float16 *extraParams), INDEX_REDUCE_OPS) BUILD_CALL_1(template Nd4jLong nd4j::NDArray::indexReduceNumber, float, (float *extraParams), INDEX_REDUCE_OPS) BUILD_CALL_1(template Nd4jLong nd4j::NDArray::indexReduceNumber, double, (double *extraParams), INDEX_REDUCE_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyBroadcast, float16, (std::initializer_list list, const nd4j::NDArray* a, nd4j::NDArray* b, float16* c), BROADCAST_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyBroadcast, float, (std::initializer_list list, const nd4j::NDArray* a, nd4j::NDArray* b, float* c), BROADCAST_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyBroadcast, double, (std::initializer_list list, const nd4j::NDArray* a, nd4j::NDArray* b, double* c), BROADCAST_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTrueBroadcast, float16,(const nd4j::NDArray* a, nd4j::NDArray* target, const bool checkTargetShape, float16* c) const, BROADCAST_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTrueBroadcast, float, (const nd4j::NDArray* a, nd4j::NDArray* target, const bool checkTargetShape, float* c) const, BROADCAST_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTrueBroadcast, double, (const nd4j::NDArray* a, nd4j::NDArray* target, const bool checkTargetShape, double* c) const, BROADCAST_OPS) BUILD_CALL_1(template nd4j::NDArray* nd4j::NDArray::applyTrueBroadcast, float16, (const nd4j::NDArray* a, float16* c) const, BROADCAST_OPS) BUILD_CALL_1(template nd4j::NDArray* nd4j::NDArray::applyTrueBroadcast, float, (const nd4j::NDArray* a, float* c) const, BROADCAST_OPS) BUILD_CALL_1(template nd4j::NDArray* nd4j::NDArray::applyTrueBroadcast, double, (const nd4j::NDArray* a, double* c) const, BROADCAST_OPS) BUILD_CALL_1(template nd4j::NDArray nd4j::NDArray::applyTrueBroadcast, float16, (const nd4j::NDArray& a, float16* c) const, BROADCAST_OPS) BUILD_CALL_1(template nd4j::NDArray nd4j::NDArray::applyTrueBroadcast, float, (const nd4j::NDArray& a, float* c) const, BROADCAST_OPS) BUILD_CALL_1(template nd4j::NDArray nd4j::NDArray::applyTrueBroadcast, double, (const nd4j::NDArray& a, double* c) const, BROADCAST_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTransform, float16, (NDArray* target, float16* extraParams), TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTransform, float, (NDArray* target, float* extraParams), TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTransform, double, (NDArray* target, double* extraParams), TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTransform, float16, (float16* extraParams), TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTransform, float, (float* extraParams), TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyTransform, double, (double* extraParams), TRANSFORM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyRandom, float16, (nd4j::random::RandomBuffer *buffer, NDArray* y, NDArray* z, float16* extraParams), RANDOM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyRandom, float, (nd4j::random::RandomBuffer *buffer, NDArray* y, NDArray* z, float* extraParams), RANDOM_OPS) BUILD_CALL_1(template void nd4j::NDArray::applyRandom, double, (nd4j::random::RandomBuffer *buffer, NDArray* y, NDArray* z, double* extraParams), RANDOM_OPS) BUILD_CALL_1(template NDArray nd4j::NDArray::transform, float16, (float16* extraParams) const, TRANSFORM_OPS) BUILD_CALL_1(template NDArray nd4j::NDArray::transform, float, (float* extraParams) const, TRANSFORM_OPS) BUILD_CALL_1(template NDArray nd4j::NDArray::transform, double, (double* extraParams) const, TRANSFORM_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template reduceAlongDimension, float, (const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template reduceAlongDimension, float16, (const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template reduceAlongDimension, double, (const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray nd4j::NDArray::template reduceAlongDims, float, (const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray nd4j::NDArray::template reduceAlongDims, float16, (const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray nd4j::NDArray::template reduceAlongDims, double, (const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template reduceAlongDimension, float, (const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template reduceAlongDimension, float16, (const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template reduceAlongDimension, double, (const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS) BUILD_CALL_1(template void nd4j::NDArray::template reduceAlongDimension, float, (NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, float * extras) const, REDUCE_OPS) BUILD_CALL_1(template void nd4j::NDArray::template reduceAlongDimension, float16, (NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, float16 * extras) const, REDUCE_OPS) BUILD_CALL_1(template void nd4j::NDArray::template reduceAlongDimension, double, (NDArray* target, const std::vector& dimension, const bool keepDims, const bool supportOldShapes, double * extras) const, REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template varianceAlongDimension, float, (const bool biasCorrected, const std::initializer_list& dimensions) const, SUMMARY_STATS_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template varianceAlongDimension, float16, (const bool biasCorrected, const std::initializer_list& dimensions) const, SUMMARY_STATS_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template varianceAlongDimension, double, (const bool biasCorrected, const std::initializer_list& dimensions) const, SUMMARY_STATS_OPS) BUILD_CALL_1(template void nd4j::NDArray::template varianceAlongDimension, float, (const NDArray *target, const bool biasCorrected, const std::initializer_list& dimensions), SUMMARY_STATS_OPS) BUILD_CALL_1(template void nd4j::NDArray::template varianceAlongDimension, float16, (const NDArray *target,const bool biasCorrected, const std::initializer_list& dimensions), SUMMARY_STATS_OPS) BUILD_CALL_1(template void nd4j::NDArray::template varianceAlongDimension, double, (const NDArray *target, const bool biasCorrected, const std::initializer_list& dimensions), SUMMARY_STATS_OPS) BUILD_CALL_1(template void nd4j::NDArray::template varianceAlongDimension, float, (const NDArray *target, const bool biasCorrected, const std::vector& dimensions), SUMMARY_STATS_OPS) BUILD_CALL_1(template void nd4j::NDArray::template varianceAlongDimension, float16, (const NDArray *target,const bool biasCorrected, const std::vector& dimensions), SUMMARY_STATS_OPS) BUILD_CALL_1(template void nd4j::NDArray::template varianceAlongDimension, double, (const NDArray *target, const bool biasCorrected, const std::vector& dimensions), SUMMARY_STATS_OPS) BUILD_CALL_1(template float nd4j::NDArray::template varianceNumber, float, (bool biasCorrected), SUMMARY_STATS_OPS) BUILD_CALL_1(template float16 nd4j::NDArray::template varianceNumber, float16, (bool biasCorrected), SUMMARY_STATS_OPS) BUILD_CALL_1(template double nd4j::NDArray::template varianceNumber, double, (bool biasCorrected), SUMMARY_STATS_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyReduce3, float, (const NDArray* other, const float* extraParams) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyReduce3, float16, (const NDArray* other, const float16* extraParams) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyReduce3, double, (const NDArray* other, const double* extraParams) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyReduce3, float, (const NDArray* other, const std::vector &dims, const float* extraParams) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyReduce3, float16, (const NDArray* other, const std::vector &dims, const float16* extraParams) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyReduce3, double, (const NDArray* other, const std::vector &dims, const double* extraParams) const, REDUCE3_OPS) BUILD_CALL_1(template void nd4j::NDArray::template applyIndexReduce, float, (const NDArray* target, const std::vector & alpha, const float* beta) const, INDEX_REDUCE_OPS) BUILD_CALL_1(template void nd4j::NDArray::template applyIndexReduce, float16, (const NDArray* target, const std::vector & alpha, const float16* beta) const, INDEX_REDUCE_OPS) BUILD_CALL_1(template void nd4j::NDArray::template applyIndexReduce, double, (const NDArray* target, const std::vector & alpha, const double* beta) const, INDEX_REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyIndexReduce, float, (const std::vector & alpha, const float* beta) const, INDEX_REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyIndexReduce, float16, (const std::vector & alpha, const float16* beta) const, INDEX_REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyIndexReduce, double, (const std::vector & alpha, const double* beta) const, INDEX_REDUCE_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyAllReduce3, float, (const nd4j::NDArray* alpha, const std::vector & beta, float const* gamma) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyAllReduce3, float16, (const nd4j::NDArray* alpha, const std::vector & beta, float16 const* gamma) const, REDUCE3_OPS) BUILD_CALL_1(template NDArray *nd4j::NDArray::template applyAllReduce3, double, (const nd4j::NDArray* alpha, const std::vector & beta, double const* gamma) const, REDUCE3_OPS) template NDArray mmul(const NDArray& left, const NDArray& right); template NDArray mmul(const NDArray& left, const NDArray& right); template NDArray mmul(const NDArray& left, const NDArray& right); // template NDArray operator-(const float, const NDArray&); // template NDArray operator-(const float16, const NDArray&); // template NDArray operator-(const double, const NDArray&); // template NDArray operator+(const float, const NDArray&); // template NDArray operator+(const float16, const NDArray&); // template NDArray operator+(const double, const NDArray&); #endif