[WIP] build time improvements (#106)
* fix pad javadoc and @see links. (#72) Signed-off-by: Robert Altena <Rob@Ra-ai.com> * [WIP] More fixes (#73) * special tests for ConstantTadHelper/ConstantShapeHelper Signed-off-by: raver119 <raver119@gmail.com> * release methods for data buffers Signed-off-by: raver119 <raver119@gmail.com> * delete temporary buffer Java side Signed-off-by: raver119 <raver119@gmail.com> * delete temporary buffer Java side Signed-off-by: raver119 <raver119@gmail.com> * delete temporary TadPack C++/Java side (#74) Signed-off-by: raver119 <raver119@gmail.com> * Zoo model TF import test updates (#75) * argLine fix, update compression_gru comment * updated comment for xception * undid but commented argLine change * updated xlnet comment * copyright headers * - new NDArray methods like()/ulike() (#77) - fix for depthwise_conv2d_bp + special test Signed-off-by: raver119 <raver119@gmail.com> * upsampling2d fix CUDA Signed-off-by: raver119 <raver119@gmail.com> * DL4J trace logging (#79) * MLN/CG trace logging for debugging Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tiny tweak Signed-off-by: AlexDBlack <blacka101@gmail.com> * strided_slice_bp shape fn leak fix Signed-off-by: raver119 <raver119@gmail.com> * SameDiff fixes and naming (#78) * remove SDVariable inplace methods * import methods * npe fix in OpVal * removed SameDiff inplace ops from tests * Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything * quick fixes * javadoc * SDVariable eval with placeholders * use regex match * better matching * fix javadoc. (#76) * fix javadoc. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace most @see with @link s. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * 4 additional tests Signed-off-by: raver119 <raver119@gmail.com> * Various DL4J/ND4J fixes (#81) * #7954 Force refresh of UI when switching tabs on overview page Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8017 Concurrent modification exception (synchronize) fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8033 Don't initialize updater in middle of writing memory crash dump Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8208 Fix shape checks for ND4J int[] creator methods Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6385 #7992 Keras import naming fixes + cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8016 Upsampling3D - add NDHWC format support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Refactor NativeOps.h to export C functions * Actually export functions from NativeOps.h * Adapt the Java wrappers in ND4J generated with JavaCPP * Create C wrappers for some of the C++ classes currently used by ND4J * remove duplicate code in createBufferDetached. (#83) Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Keras model import - updater lr fix (#84) * Keras model import - updater lr fix Signed-off-by: eraly <susan.eraly@gmail.com> * Keras model import - updater lr fix, cleanup Signed-off-by: eraly <susan.eraly@gmail.com> * Fix functions of OpaqueVariablesSet * SameDiff Convolution Config validation, better output methods (#82) * Conv Config validation & tests Signed-off-by: Ryan Nett <rnett@skymind.io> * stackOutputs utility method Signed-off-by: Ryan Nett <rnett@skymind.io> * use constructor for validation, support negative kernel sizes (infered from weights) Signed-off-by: Ryan Nett <rnett@skymind.io> * better output methods Signed-off-by: Ryan Nett <rnett@skymind.io> * move output to be with fit and evaluate Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * refactor duplicate code from pad methods. (#86) * refactor duplicate code from pad methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace switch with if. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Various ND4J/DL4J fixes and improvements (#87) * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6488 ElementWiseVertex broadcast support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Constructors and broadcast supported it Transforms.max/min Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8054 ElementWiseVertex now supports broadcast inputs Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8057 Nd4j.create overload dtype fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7551 ND4J Shape validation fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * [WIP] Numpy boolean import (#91) * numpy bool type Signed-off-by: raver119 <raver119@gmail.com> * numpy bool java side Signed-off-by: raver119 <raver119@gmail.com> * remove create method with unused parameter. (#89) * remove create method with unused parameter. * removed more unused methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * removing more unused code. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * last removal of unused code. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove createSparse methods. (#92) Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Various ND4J/DL4J fixes (#90) * Deprecate Old*Op instances Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8063 #8054 Broadcast exceptions + cleanup inplace ops Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove bad test condition Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7993 Fix shape function issue in crop_and_resize op Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J SameDiff lambda layer fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8029 Fix for pnorm backprop math Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8038 Fix Op profiler NaN/Inf triggering + add tests (#93) Signed-off-by: AlexDBlack <blacka101@gmail.com> * createUninitializedDetached refactoring. (#94) * wip * update interface, add null implementations. * Breaking one test in a weird way. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * createUninitializedDetached refactored. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * cuda build fix for issues introduced by recent refactoring Signed-off-by: raver119 <raver119@gmail.com> * [WIP] More of CUDA (#95) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * Implementation of hashcode cuda helper. Working edition. * Fixed parallel test input arangements. * Fixed tests for hashcode op. * Fixed shape calculation for image:crop_and_resize op and test. * NativeOps tests. Initial test suite. * Added tests for indexReduce methods. * Added test on execBroadcast with NDArray as dimensions. * Added test on execBroadcastBool with NDArray as dimensions. * Added tests on execPairwiseTransform and execPairwiseTransofrmBool. * Added tests for execReduce with scalar results. * Added reduce tests for non-empty dims array. * Added tests for reduce3. * Added tests for execScalar. * Added tests for execSummaryStats. * - provide cpu/cuda code for batch_to_space - testing it Signed-off-by: Yurii <yurii@skymind.io> * - remove old test for batch_to_space (had wrong format and numbers were not checked) Signed-off-by: Yurii <yurii@skymind.io> * Fixed complilation errors with test. * Added test for execTransformFloat. * Added test for execTransformSame. * Added test for execTransformBool. * Added test for execTransformStrict. * Added tests for execScalar/execScalarBool with TADs. * Added test for flatten. * - provide cpu/cuda code for space_to_Batch operaion Signed-off-by: Yurii <yurii@skymind.io> * Added test for concat. * comment unnecessary stuff in s_t_b Signed-off-by: Yurii <yurii@skymind.io> * Added test for specialConcat. * Added tests for memcpy/set routines. * Fixed pullRow cuda test. * Added pullRow test. * Added average test. * - correct typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...) Signed-off-by: Yurii <yurii@skymind.io> * - debugging and fixing cuda tests in JavaInteropTests file Signed-off-by: Yurii <yurii@skymind.io> * - correct some tests Signed-off-by: Yurii <yurii@skymind.io> * Added test for shuffle. * Fixed ops declarations. * Restored omp and added shuffle test. * Added convertTypes test. * Added tests for execRandom. Eliminated usage of RandomBuffer with NativeOps. * Added sort tests. * Added tests for execCustomOp. * - further debuging and fixing tests terminated with crash Signed-off-by: Yurii <yurii@skymind.io> * Added tests for calculateOutputShapes. * Addded Benchmarks test. * Commented benchmark tests. * change assertion Signed-off-by: raver119 <raver119@gmail.com> * Added tests for apply_sgd op. Added cpu helper for that op. * Implement cuda helper for aplly_sgd op. Fixed tests for NativeOps. * Added test for assign broadcastable. * Added tests for assign_bp op. * Added tests for axpy op. * - assign/execScalar/execTransformAny signature change - minor test fix Signed-off-by: raver119 <raver119@gmail.com> * Fixed axpy op. * meh Signed-off-by: raver119 <raver119@gmail.com> * - fix tests for nativeOps::concat Signed-off-by: Yurii <yurii@skymind.io> * sequential transform/scalar Signed-off-by: raver119 <raver119@gmail.com> * allow nested parallelism Signed-off-by: raver119 <raver119@gmail.com> * assign_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * block setRNG fix Signed-off-by: raver119 <raver119@gmail.com> * enable parallelism by default Signed-off-by: raver119 <raver119@gmail.com> * enable nested parallelism by default Signed-off-by: raver119 <raver119@gmail.com> * Added cuda implementation for row_count helper. * Added implementation for tnse gains op helper. * - take into account possible situations when input arrays are empty in reduce_ cuda stuff Signed-off-by: Yurii <yurii@skymind.io> * Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces. * Added kernel for tsne/symmetrized op heleper. * Implementation of tsne/symmetrized op cuda helper. Working edition. * Eliminated waste printfs. * Added test for broadcastgradientargs op. * host-only fallback for empty reduce float Signed-off-by: raver119 <raver119@gmail.com> * - some tests fixes Signed-off-by: Yurii <yurii@skymind.io> * - correct the rest of reduce_ stuff Signed-off-by: Yurii <yurii@skymind.io> * - further correction of reduce_ stuff Signed-off-by: Yurii <yurii@skymind.io> * Added test for Cbow op. Also added cuda implementation for cbow helpers. * - improve code of stack operation for scalar case Signed-off-by: Yurii <yurii@skymind.io> * - provide cuda kernel for gatherND operation Signed-off-by: Yurii <yurii@skymind.io> * Implementation of cbow helpers with cuda kernels. * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * - further correction of cuda stuff Signed-off-by: Yurii <yurii@skymind.io> * Implementatation of cbow op helper with cuda kernels. Working edition. * Skip random testing for cudablas case. * lstmBlockCell context fix Signed-off-by: raver119 <raver119@gmail.com> * Added tests for ELU and ELU_BP ops. * Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops. * Added tests for neq_scalar. * Added test for noop. * - further work on clipbynorm_bp Signed-off-by: Yurii <yurii@skymind.io> * - get rid of concat op call, use instead direct concat helper call Signed-off-by: Yurii <yurii@skymind.io> * lstmBlockCell context fix Signed-off-by: raver119 <raver119@gmail.com> * Added tests for lrelu and lrelu_bp. * Added tests for selu and selu_bp. * Fixed lrelu derivative helpers. * - some corrections in lstm Signed-off-by: Yurii <yurii@skymind.io> * operator * result shape fix Signed-off-by: raver119 <raver119@gmail.com> * - correct typo in lstmCell Signed-off-by: Yurii <yurii@skymind.io> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * CUDA inverse broadcast bool fix Signed-off-by: raver119 <raver119@gmail.com> * disable MMAP test for CUDA Signed-off-by: raver119 <raver119@gmail.com> * BooleanOp syncToDevice Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * additional data types for im2col/col2im Signed-off-by: raver119 <raver119@gmail.com> * Added test for firas_sparse op. * one more RandomBuffer test excluded Signed-off-by: raver119 <raver119@gmail.com> * Added tests for flatten op. * Added test for Floor op. * bunch of tests fixed Signed-off-by: raver119 <raver119@gmail.com> * mmulDot tests fixed Signed-off-by: raver119 <raver119@gmail.com> * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Implemented floordiv_bp op and tests. * Fixed scalar case with cuda implementation for bds. * - work on cuda kernel for clip_by_norm backprop op is completed Signed-off-by: Yurii <yurii@skymind.io> * Eliminate cbow crach. * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Eliminated abortion with batched nlp test. * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Fixed shared flag initializing. * disabled bunch of cpu workspaces tests Signed-off-by: raver119 <raver119@gmail.com> * scalar operators fix: missing registerSpecialUse call Signed-off-by: raver119 <raver119@gmail.com> * Fixed logdet for cuda and tests. * - correct clipBynorm_bp Signed-off-by: Yurii <yurii@skymind.io> * Fixed crop_and_resize shape datatype. * - correct some mmul tests Signed-off-by: Yurii <yurii@skymind.io> * build fix Signed-off-by: raver119 <raver119@gmail.com> * exclude two methods for JNI Signed-off-by: raver119 <raver119@gmail.com> * exclude two methods for JNI Signed-off-by: raver119 <raver119@gmail.com> * exclude two methods for JNI (#97) Signed-off-by: raver119 <raver119@gmail.com> * temporary stack fix Signed-off-by: raver119 <raver119@gmail.com> * couple of legacy groups reorganized into separate compialtion units Signed-off-by: raver119 <raver119@gmail.com> * wrong include Signed-off-by: raver119 <raver119@gmail.com> * wrong include Signed-off-by: raver119 <raver119@gmail.com> * ReductionLoops_float split Signed-off-by: raver119 <raver119@gmail.com> * maximum Signed-off-by: raver119 <raver119@gmail.com> * some more rearrangements Signed-off-by: raver119 <raver119@gmail.com> * spare ifdef Signed-off-by: raver119 <raver119@gmail.com> * mirror pad Signed-off-by: raver119 <raver119@gmail.com> * - reduce_float split - mcmodel Signed-off-by: raver119 <raver119@gmail.com> * bad include fix Signed-off-by: raver119 <raver119@gmail.com> * norelax Signed-off-by: raver119 <raver119@gmail.com> * norelax Signed-off-by: raver119 <raver119@gmail.com> * norelax Signed-off-by: raver119 <raver119@gmail.com> * norelax Signed-off-by: raver119 <raver119@gmail.com> * norelax Signed-off-by: raver119 <raver119@gmail.com> * norelax gone Signed-off-by: raver119 <raver119@gmail.com> * get back sm Signed-off-by: raver119 <raver119@gmail.com> * fix couple of tests for msvc Signed-off-by: raver119 <raver119@gmail.com> * fix couple of tests for msvc Signed-off-by: raver119 <raver119@gmail.com> * compress-all Signed-off-by: raver119 <raver119@gmail.com> * reduced arch list Signed-off-by: raver119 <raver119@gmail.com> * compress-all Signed-off-by: raver119 <raver119@gmail.com> * reduced arch list Signed-off-by: raver119 <raver119@gmail.com> * all compute capabilities option for tests Signed-off-by: raver119 <raver119@gmail.com>master
parent
c78f5a8225
commit
24e43e9856
|
@ -99,7 +99,7 @@ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
|
||||||
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
|
||||||
# using Visual Studio C++
|
# using Visual Studio C++
|
||||||
|
|
||||||
set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
|
set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /w")
|
||||||
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
# using GCC
|
# using GCC
|
||||||
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
|
||||||
|
@ -118,16 +118,6 @@ if(!CUDA_BLAS)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# TODO: get rid of this once problem confirmed solved
|
|
||||||
#if (APPLE)
|
|
||||||
# if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
|
||||||
# if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER 6.0 OR "${CMAKE_C_COMPILER_VERSION}" VERSION_EQUAL 6.0)
|
|
||||||
# SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mavx512f,-mavx512vl,-mavx512bw,-mavx512dq,-mavx512cd ")
|
|
||||||
# endif()
|
|
||||||
# endif()
|
|
||||||
#endif()
|
|
||||||
|
|
||||||
|
|
||||||
if(CUDA_BLAS)
|
if(CUDA_BLAS)
|
||||||
message("Build cublas")
|
message("Build cublas")
|
||||||
find_package(CUDA)
|
find_package(CUDA)
|
||||||
|
@ -173,32 +163,32 @@ if(CUDA_BLAS)
|
||||||
if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10
|
if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10
|
||||||
if ("${COMPUTE}" STREQUAL "all")
|
if ("${COMPUTE}" STREQUAL "all")
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 --fatbin -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||||
endif()
|
endif()
|
||||||
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
|
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
|
||||||
if ("${COMPUTE}" STREQUAL "all")
|
if ("${COMPUTE}" STREQUAL "all")
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||||
endif()
|
endif()
|
||||||
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0
|
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0
|
||||||
if ("${COMPUTE}" STREQUAL "all")
|
if ("${COMPUTE}" STREQUAL "all")
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
if ("${COMPUTE}" STREQUAL "all")
|
if ("${COMPUTE}" STREQUAL "all")
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||||
endif()
|
endif()
|
||||||
|
@ -220,7 +210,7 @@ if(CUDA_BLAS)
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=compute_${COMPUTE})
|
||||||
endif()
|
endif()
|
||||||
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
|
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
|
||||||
if ("${COMPUTE}" STREQUAL "all")
|
if ("${COMPUTE}" STREQUAL "all")
|
||||||
|
|
|
@ -40,7 +40,7 @@ namespace nd4j {
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "ReductionLoops.hpp"
|
||||||
|
#include <pointercast.h>
|
||||||
|
#include <types/types.h>
|
||||||
|
|
||||||
|
using namespace simdOps;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
template <typename OpType>
|
||||||
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
||||||
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Y>
|
||||||
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
|
Nd4jLong *tadOffsets, Y *extraParams) {
|
||||||
|
|
||||||
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "ReductionLoops.hpp"
|
||||||
|
#include <pointercast.h>
|
||||||
|
#include <types/types.h>
|
||||||
|
|
||||||
|
using namespace simdOps;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
template <typename OpType>
|
||||||
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
||||||
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Y>
|
||||||
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
|
Nd4jLong *tadOffsets, Y *extraParams) {
|
||||||
|
|
||||||
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "ReductionLoops.hpp"
|
||||||
|
#include <pointercast.h>
|
||||||
|
#include <types/types.h>
|
||||||
|
|
||||||
|
using namespace simdOps;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
template <typename OpType>
|
||||||
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
||||||
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Y>
|
||||||
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
|
Nd4jLong *tadOffsets, Y *extraParams) {
|
||||||
|
|
||||||
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -220,7 +220,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2);
|
||||||
|
@ -231,5 +231,6 @@ namespace functions {
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_4);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_5);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_6);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../broadcasting.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace broadcast {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../pairwise.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace pairwise_transforms {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce3.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce3 {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce3.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce3 {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce3.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce3 {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce3.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce3 {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce/reduce_float.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce/reduce_float.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce/reduce_float.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../reduce/reduce_float.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace reduce {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_4);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_5);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_6);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_7);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_8);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../../scalar.chpp"
|
||||||
|
|
||||||
|
namespace functions {
|
||||||
|
namespace scalar {
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_9);
|
||||||
|
}
|
||||||
|
}
|
|
@ -106,7 +106,7 @@ void __host__ PairWiseTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cuda
|
||||||
DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS);
|
DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2);
|
||||||
|
@ -117,6 +117,7 @@ void __host__ PairWiseTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cuda
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8);
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9);
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -304,7 +304,7 @@ __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES);
|
//BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -559,7 +559,7 @@ __host__ void Reduce3<X,Z>::execScalar(dim3 launchDims, cudaStream_t *stream,
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES);
|
//BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -165,18 +165,6 @@ void ScalarTransform<X,Y,Z>::executeCudaAlongDimension(dim3& launchDims, cudaStr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_0);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_1);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_2);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_3);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_4);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_5);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_6);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_7);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_8);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_9);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <array/ResultSet.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
|
||||||
|
|
||||||
|
__shared__ int arrIdx, blocksPerArr;
|
||||||
|
__shared__ T *x, *z;
|
||||||
|
__shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen, arrLenPerBlock, start, end;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
|
||||||
|
arrIdx = blockIdx.x / blocksPerArr;
|
||||||
|
|
||||||
|
x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[arrIdx]);
|
||||||
|
z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[arrIdx]);
|
||||||
|
xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[arrIdx];
|
||||||
|
zShapeInfo = reinterpret_cast<Nd4jLong**>(pzShapeInfo)[arrIdx];
|
||||||
|
arrLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr; // ceil
|
||||||
|
|
||||||
|
start = (blockIdx.x % blocksPerArr) * arrLenPerBlock;
|
||||||
|
end = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x)
|
||||||
|
z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)];
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
|
||||||
|
|
||||||
|
concatCuda<T><<<512, 256, 1024, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||||
|
|
||||||
|
const int numOfArrs = inArrs.size();
|
||||||
|
for(int i = 0; i < numOfArrs; ++i)
|
||||||
|
if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
|
||||||
|
|
||||||
|
const int rank = inArrs[0]->rankOf();
|
||||||
|
const int rank2 = 2*rank;
|
||||||
|
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
|
||||||
|
|
||||||
|
// take into account indices for first array
|
||||||
|
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
|
||||||
|
|
||||||
|
// loop through the rest of input arrays
|
||||||
|
for(int i = 1; i < numOfArrs; ++i) {
|
||||||
|
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
|
||||||
|
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<NDArray*> outSubArrs(numOfArrs);
|
||||||
|
for(int i = 0; i < numOfArrs; ++i)
|
||||||
|
outSubArrs[i] = new NDArray(output(indices[i], true));
|
||||||
|
|
||||||
|
// prepare arrays of pointers on buffers and shapes
|
||||||
|
std::vector<void*> hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
|
||||||
|
std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
|
||||||
|
for(int i = 0; i < numOfArrs; ++i) {
|
||||||
|
hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer();
|
||||||
|
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
|
||||||
|
hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
|
||||||
|
hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
// allocate and copy all buffers and shapes arrays to global memory
|
||||||
|
PointersManager manager(context, "helpers::concat");
|
||||||
|
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
|
||||||
|
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
|
||||||
|
void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
|
||||||
|
void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
for(int i = 0; i < numOfArrs; ++i)
|
||||||
|
delete outSubArrs[i];
|
||||||
|
|
||||||
|
for(int i = 0; i < numOfArrs; ++i)
|
||||||
|
inArrs[i]->tickReadHost();
|
||||||
|
|
||||||
|
output.tickWriteDevice();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <array/ResultSet.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// x - input, y - indices, z - output
|
||||||
|
template<typename X, typename Y>
|
||||||
|
__global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
const void *vy, const Nd4jLong *yShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const X*>(vx);
|
||||||
|
const auto y = reinterpret_cast<const Y*>(vy);
|
||||||
|
auto z = reinterpret_cast<X*>(vz);
|
||||||
|
|
||||||
|
__shared__ int xRank, yRank, zRank, maxRank, yLastDim;
|
||||||
|
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
xRank = shape::rank(xShapeInfo);
|
||||||
|
yRank = shape::rank(yShapeInfo);
|
||||||
|
zRank = shape::rank(zShapeInfo);
|
||||||
|
maxRank = nd4j::math::nd4j_max<int>(yRank, nd4j::math::nd4j_max<int>(xRank, zRank));
|
||||||
|
|
||||||
|
zLen = shape::length(zShapeInfo);
|
||||||
|
yLastDim = yShapeInfo[yRank];
|
||||||
|
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto coord = sharedMem + threadIdx.x * maxRank;
|
||||||
|
|
||||||
|
Nd4jLong *zCoordStart, *xCoordStart;
|
||||||
|
|
||||||
|
if(yLastDim == xRank) {
|
||||||
|
zCoordStart = coord;
|
||||||
|
xCoordStart = coord;
|
||||||
|
}
|
||||||
|
if(zRank >= xRank) {
|
||||||
|
zCoordStart = coord;
|
||||||
|
xCoordStart = coord + zRank - xRank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
zCoordStart = coord + xRank - zRank;
|
||||||
|
xCoordStart = coord;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||||
|
|
||||||
|
shape::index2coords(zRank, zShapeInfo + 1, i, zLen, zCoordStart);
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + zRank + 1, zCoordStart, zRank);
|
||||||
|
|
||||||
|
// last y coordinate
|
||||||
|
int coordToRestore;
|
||||||
|
if(yLastDim != xRank)
|
||||||
|
coordToRestore = static_cast<int>(zCoordStart[yRank - 1]);
|
||||||
|
|
||||||
|
zCoordStart[yRank - 1] = 0; // last y coordinate
|
||||||
|
const auto yOffset = shape::getOffset(0, yShapeInfo + 1, yShapeInfo + yRank + 1, zCoordStart, yRank);
|
||||||
|
|
||||||
|
//restore z coordinate
|
||||||
|
if(yLastDim != xRank)
|
||||||
|
zCoordStart[yRank - 1] = coordToRestore;
|
||||||
|
|
||||||
|
// construct coordinates for x
|
||||||
|
for(uint j = 0; j < yLastDim; ++j)
|
||||||
|
xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]]; // last stride
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, xCoordStart, xRank);
|
||||||
|
|
||||||
|
z[zOffset] = x[xOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Y>
|
||||||
|
static void gatherNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
const void *vy, const Nd4jLong *yShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
gatherNDCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void gatherNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
|
||||||
|
|
||||||
|
const int maxRank = nd4j::math::nd4j_max<int>(indices.rankOf(), nd4j::math::nd4j_max<int>(input.rankOf(), output.rankOf()));
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS;
|
||||||
|
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = 8 * threadsPerBlock * maxRank + 128;
|
||||||
|
|
||||||
|
const auto xType = input.dataType();
|
||||||
|
const auto yType = indices.dataType();
|
||||||
|
|
||||||
|
PointersManager manager(context, "gatherND");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input, &indices});
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input, &indices});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,118 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void reluDerivative__(NDArray* theFirst, NDArray* theSecond) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return x > (T) 0.f ? y : T(0.f);
|
||||||
|
};
|
||||||
|
|
||||||
|
theFirst->applyPairwiseLambda(theSecond, functor, nullptr);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void reluDerivative__, (NDArray* input, NDArray* epsilon), FLOAT_TYPES);
|
||||||
|
|
||||||
|
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return x > (T)0.f ? y : T(0.f);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void reluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void relu6Derivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return x > (T)0.f && x < (T)6.f? y : T(0.f);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void relu6Derivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return x >= (T)0.f? y : T(0.f);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void leakyReluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return y * nd4j::math::nd4j_eluderivative<T,T>(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void eluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return y * simdOps::SELUDerivative<T>::op(x, nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void seluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
linkage void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
T th = nd4j::math::nd4j_tanh<T,T>(x);
|
||||||
|
return y * ((T)1.0f - (th * th));
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void tanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
// return static_cast<X>(d2) * simdOps::HardTanhDerivative<X>::op(d1, nullptr);
|
||||||
|
template <typename T>
|
||||||
|
linkage void hardTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
T th = nd4j::math::nd4j_tanh<T,T>(x);
|
||||||
|
return y * simdOps::HardTanhDerivative<T>::op(x, nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void hardTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return y * simdOps::RationalTanhDerivative<T>::op(x, nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void rationalTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
linkage void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto functor = LAMBDA_TT(x, y){
|
||||||
|
return x > (T) 0.0f ? y * (nd4j::math::nd4j_tanhderivative<T,T>(x)) : (T) 0.0f;
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void rectifiedTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
||||||
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -25,93 +25,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename T>
|
|
||||||
linkage void reluDerivative__(NDArray* theFirst, NDArray* theSecond) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x > (T) 0.f ? y : T(0.f);
|
|
||||||
};
|
|
||||||
|
|
||||||
theFirst->applyPairwiseLambda(theSecond, functor, nullptr);
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void reluDerivative__, (NDArray* input, NDArray* epsilon), FLOAT_TYPES);
|
|
||||||
|
|
||||||
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x > (T)0.f ? y : T(0.f);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void reluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void relu6Derivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x > (T)0.f && x < (T)6.f? y : T(0.f);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void relu6Derivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x >= (T)0.f? y : T(0.f);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void leakyReluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return y * nd4j::math::nd4j_eluderivative<T,T>(x);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void eluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return y * simdOps::SELUDerivative<T>::op(x, nullptr);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void seluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
linkage void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
linkage void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
||||||
|
@ -180,70 +93,6 @@ namespace helpers {
|
||||||
BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
linkage void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
T th = nd4j::math::nd4j_tanh<T,T>(x);
|
|
||||||
return y * ((T)1.0f - (th * th));
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void tanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
// return static_cast<X>(d2) * simdOps::HardTanhDerivative<X>::op(d1, nullptr);
|
|
||||||
template <typename T>
|
|
||||||
linkage void hardTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
T th = nd4j::math::nd4j_tanh<T,T>(x);
|
|
||||||
return y * simdOps::HardTanhDerivative<T>::op(x, nullptr);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void hardTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return y * simdOps::RationalTanhDerivative<T>::op(x, nullptr);
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void rationalTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
linkage void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x > (T) 0.0f ? y * (nd4j::math::nd4j_tanhderivative<T,T>(x)) : (T) 0.0f;
|
|
||||||
};
|
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void rectifiedTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
// X f = (X) 1.0f + nd4j::math::nd4j_abs<X>(d1);
|
// X f = (X) 1.0f + nd4j::math::nd4j_abs<X>(d1);
|
||||||
// return (X) d2 * ((X) 1.0f / (f * f));
|
// return (X) d2 * ((X) 1.0f / (f * f));
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
|
||||||
|
|
||||||
|
auto lambdaX = LAMBDA_TTT(_e, _x, _y) {
|
||||||
|
return _x >= _y ? _e : (T) 0.;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto lambdaY = LAMBDA_TTT(_e, _x, _y) {
|
||||||
|
return _x <= _y ? _e : (T) 0.;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
if (x->isSameShape(y)) {
|
||||||
|
// PWT case case
|
||||||
|
|
||||||
|
// X gradient
|
||||||
|
epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX);
|
||||||
|
|
||||||
|
// Y gradient
|
||||||
|
epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
|
||||||
|
|
||||||
|
} else if (y->isScalar()) {
|
||||||
|
T s = y->e<T>(0);
|
||||||
|
auto lambdaS = LAMBDA_TT(_e, _x, s) {
|
||||||
|
return _x >= s ? _e : (T) 0.;
|
||||||
|
};
|
||||||
|
|
||||||
|
// scalar case
|
||||||
|
auto tmp = epsNext->reduceNumber(reduce::Sum);
|
||||||
|
if (x <= y)
|
||||||
|
gradY->assign(tmp);
|
||||||
|
else
|
||||||
|
gradY->assign(0.0f);
|
||||||
|
|
||||||
|
epsNext->applyPairwiseLambda(x, lambdaS, gradX);
|
||||||
|
} else {
|
||||||
|
// broadcast case
|
||||||
|
|
||||||
|
// in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape)
|
||||||
|
auto preX = x->dup();
|
||||||
|
auto preY = y->dup();
|
||||||
|
|
||||||
|
auto targetShape = epsNext->getShapeAsVector();
|
||||||
|
|
||||||
|
preX->tileToShape(targetShape);
|
||||||
|
preY->tileToShape(targetShape);
|
||||||
|
|
||||||
|
epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX);
|
||||||
|
epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY);
|
||||||
|
|
||||||
|
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
|
||||||
|
auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo());
|
||||||
|
|
||||||
|
if (axisX.size() > 0) {
|
||||||
|
auto sum = preX->reduceAlongDimension(reduce::Sum, axisX);
|
||||||
|
gradX->assign(sum);
|
||||||
|
delete sum;
|
||||||
|
} else
|
||||||
|
gradX->assign(preX);
|
||||||
|
|
||||||
|
if (axisY.size() > 0) {
|
||||||
|
auto sum = preY->reduceAlongDimension(reduce::Sum, axisY);
|
||||||
|
gradY->assign(sum);
|
||||||
|
delete sum;
|
||||||
|
} else
|
||||||
|
gradY->assign(preY);
|
||||||
|
|
||||||
|
|
||||||
|
delete preX;
|
||||||
|
delete preY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void maximumBPFunctor(nd4j::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
|
||||||
|
NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext});
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext});
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,234 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <array/ResultSet.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T, typename Z>
|
||||||
|
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
|
auto output = reinterpret_cast<Z*>(voutput);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
T mVal = -DataTypeUtils::max<T>();
|
||||||
|
Z mIdx(0);
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
|
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
||||||
|
auto val = x[shape::getIndexOffset(e, xShape, length)];;
|
||||||
|
if (mVal < val)
|
||||||
|
mIdx = static_cast<Z>(e);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
output[shape::getIndexOffset(e, outputShape, length)] = mIdx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Z>
|
||||||
|
static void mergeMaxIndex_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
std::vector<void *> inBuffers(inArrs.size());
|
||||||
|
std::vector<void *> inShapes(inArrs.size());
|
||||||
|
|
||||||
|
for (int e = 0; e < inArrs.size(); e++) {
|
||||||
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeMaxIndex");
|
||||||
|
|
||||||
|
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
||||||
|
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
||||||
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
|
global_mergeMaxIndex_<T,Z><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void mergeMaxIndex_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
T mVal = -DataTypeUtils::max<T>();
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
|
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
||||||
|
auto val = x[shape::getIndexOffset(e, xShape, length)];;
|
||||||
|
if (mVal < val)
|
||||||
|
mVal = val;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
output[shape::getIndexOffset(e, outputShape, length)] = mVal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void mergeMax_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
std::vector<void *> inBuffers(inArrs.size());
|
||||||
|
std::vector<void *> inShapes(inArrs.size());
|
||||||
|
|
||||||
|
for (int e = 0; e < inArrs.size(); e++) {
|
||||||
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeMax");
|
||||||
|
|
||||||
|
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
||||||
|
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
||||||
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
|
global_mergeMax_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void mergeMax_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
void mergeMax(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
T sum(0.0f);
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
|
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
||||||
|
|
||||||
|
sum += x[shape::getIndexOffset(e, xShape, length)];
|
||||||
|
}
|
||||||
|
|
||||||
|
output[shape::getIndexOffset(e, outputShape, length)] = sum / numArrays;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAvg_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
std::vector<void *> inBuffers(inArrs.size());
|
||||||
|
std::vector<void *> inShapes(inArrs.size());
|
||||||
|
|
||||||
|
for (int e = 0; e < inArrs.size(); e++) {
|
||||||
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeAvg");
|
||||||
|
|
||||||
|
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
||||||
|
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
||||||
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
|
global_mergeAvg_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void mergeAvg_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
void mergeAvg(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
T sum(0.0f);
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
|
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
||||||
|
|
||||||
|
sum += x[shape::getIndexOffset(e, xShape, length)];
|
||||||
|
}
|
||||||
|
|
||||||
|
output[shape::getIndexOffset(e, outputShape, length)] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAdd_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
std::vector<void *> inBuffers(inArrs.size());
|
||||||
|
std::vector<void *> inShapes(inArrs.size());
|
||||||
|
|
||||||
|
for (int e = 0; e < inArrs.size(); e++) {
|
||||||
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeAdd");
|
||||||
|
|
||||||
|
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
||||||
|
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
||||||
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
|
global_mergeAdd_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
void mergeAdd(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -100,78 +100,6 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
template <typename T>
|
|
||||||
void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
|
|
||||||
|
|
||||||
auto lambdaX = LAMBDA_TTT(_e, _x, _y) {
|
|
||||||
return _x >= _y ? _e : (T) 0.;
|
|
||||||
};
|
|
||||||
|
|
||||||
auto lambdaY = LAMBDA_TTT(_e, _x, _y) {
|
|
||||||
return _x <= _y ? _e : (T) 0.;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
if (x->isSameShape(y)) {
|
|
||||||
// PWT case case
|
|
||||||
|
|
||||||
// X gradient
|
|
||||||
epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX);
|
|
||||||
|
|
||||||
// Y gradient
|
|
||||||
epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
|
|
||||||
|
|
||||||
} else if (y->isScalar()) {
|
|
||||||
T s = y->e<T>(0);
|
|
||||||
auto lambdaS = LAMBDA_TT(_e, _x, s) {
|
|
||||||
return _x >= s ? _e : (T) 0.;
|
|
||||||
};
|
|
||||||
|
|
||||||
// scalar case
|
|
||||||
auto tmp = epsNext->reduceNumber(reduce::Sum);
|
|
||||||
if (x <= y)
|
|
||||||
gradY->assign(tmp);
|
|
||||||
else
|
|
||||||
gradY->assign(0.0f);
|
|
||||||
|
|
||||||
epsNext->applyPairwiseLambda(x, lambdaS, gradX);
|
|
||||||
} else {
|
|
||||||
// broadcast case
|
|
||||||
|
|
||||||
// in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape)
|
|
||||||
auto preX = x->dup();
|
|
||||||
auto preY = y->dup();
|
|
||||||
|
|
||||||
auto targetShape = epsNext->getShapeAsVector();
|
|
||||||
|
|
||||||
preX->tileToShape(targetShape);
|
|
||||||
preY->tileToShape(targetShape);
|
|
||||||
|
|
||||||
epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX);
|
|
||||||
epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY);
|
|
||||||
|
|
||||||
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
|
|
||||||
auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo());
|
|
||||||
|
|
||||||
if (axisX.size() > 0) {
|
|
||||||
auto sum = preX->reduceAlongDimension(reduce::Sum, axisX);
|
|
||||||
gradX->assign(sum);
|
|
||||||
delete sum;
|
|
||||||
} else
|
|
||||||
gradX->assign(preX);
|
|
||||||
|
|
||||||
if (axisY.size() > 0) {
|
|
||||||
auto sum = preY->reduceAlongDimension(reduce::Sum, axisY);
|
|
||||||
gradY->assign(sum);
|
|
||||||
delete sum;
|
|
||||||
} else
|
|
||||||
gradY->assign(preY);
|
|
||||||
|
|
||||||
|
|
||||||
delete preX;
|
|
||||||
delete preY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void minimumBPFunctor(nd4j::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
|
void minimumBPFunctor(nd4j::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
|
||||||
NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext});
|
NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext});
|
||||||
|
@ -181,15 +109,7 @@ namespace nd4j {
|
||||||
NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext});
|
NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext});
|
||||||
}
|
}
|
||||||
|
|
||||||
void maximumBPFunctor(nd4j::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
|
|
||||||
NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext});
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES);
|
|
||||||
|
|
||||||
NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext});
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,283 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <array/ResultSet.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// x - input, y - paddings, z - output
|
||||||
|
template<typename X, typename Y>
|
||||||
|
__global__ static void padCuda(const int mode,
|
||||||
|
const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
const void *vy, const Nd4jLong *yShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo,
|
||||||
|
const void *vPadVal) {
|
||||||
|
|
||||||
|
const X padVal = *reinterpret_cast<const X*>(vPadVal);
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const X*>(vx);
|
||||||
|
const auto y = reinterpret_cast<const Y*>(vy);
|
||||||
|
auto z = reinterpret_cast<X*>(vz);
|
||||||
|
|
||||||
|
__shared__ int rank, rankMinusOne;
|
||||||
|
__shared__ Nd4jLong zLen, yLen, totalThreads, *coords, *xShape, *zShape, *xStride, *zStride, shift1, shift2, yStride0;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
coords = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
zLen = shape::length(zShapeInfo);
|
||||||
|
xShape = shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo));
|
||||||
|
zShape = shape::shapeOf(const_cast<Nd4jLong*>(zShapeInfo));
|
||||||
|
xStride = shape::stride(const_cast<Nd4jLong*>(xShapeInfo));
|
||||||
|
zStride = shape::stride(const_cast<Nd4jLong*>(zShapeInfo));
|
||||||
|
yStride0 = shape::stride(const_cast<Nd4jLong*>(yShapeInfo))[0];
|
||||||
|
rank = shape::rank(xShapeInfo);
|
||||||
|
zLen = shape::length(zShapeInfo);
|
||||||
|
yLen = 2 * rank;
|
||||||
|
rankMinusOne = rank - 1;
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC
|
||||||
|
shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto xzCoord = coords + threadIdx.x * rank; // we use xzCoord storage both for x and z arrays
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
if(mode == 0) { // CONSTANT case
|
||||||
|
|
||||||
|
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, zShape, i, zLen, xzCoord);
|
||||||
|
const auto zOffset = shape::getOffset(0, zShape, zStride, xzCoord, rank);
|
||||||
|
|
||||||
|
bool within = true;
|
||||||
|
for(int j = rankMinusOne; j >= 0; --j) {
|
||||||
|
if(xShape[j] == zShape[j]) continue;
|
||||||
|
const auto left = y[shape::getIndexOffset(yStride0 * j, yShapeInfo, yLen)];
|
||||||
|
if(xzCoord[j] < left || xzCoord[j] >= left + xShape[j]) {within = false; break;}
|
||||||
|
else {xzCoord[j] = xzCoord[j] - left;}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(within)
|
||||||
|
z[zOffset] = x[shape::getOffset(0, xShape, xStride, xzCoord, rank)];
|
||||||
|
else
|
||||||
|
z[zOffset] = padVal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else { // REFLECT and SYMMETRIC cases
|
||||||
|
|
||||||
|
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, zShape, i, zLen, xzCoord);
|
||||||
|
const auto zOffset = shape::getOffset(0, zShape, zStride, xzCoord, rank);
|
||||||
|
|
||||||
|
for(int j = rankMinusOne; j >= 0; --j) {
|
||||||
|
|
||||||
|
if(xShape[j] == zShape[j]) continue;
|
||||||
|
xzCoord[j] = xzCoord[j] - y[shape::getIndexOffset(yStride0 * j, yShapeInfo, yLen)]; // are ready to fill middle (within input dimension range)
|
||||||
|
if(xzCoord[j] < 0) xzCoord[j] = -xzCoord[j] - shift1; // means fill from left
|
||||||
|
else if(xzCoord[j] >= xShape[j]) xzCoord[j] = 2 * xShape[j] - xzCoord[j] - shift2; // means fill from right
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(0, xShape, xStride, xzCoord, rank);
|
||||||
|
z[zOffset] = x[xOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Y>
|
||||||
|
static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const int mode,
|
||||||
|
const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
const void *vy, const Nd4jLong *yShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo,
|
||||||
|
const void* padVal) {
|
||||||
|
|
||||||
|
padCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "pad");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue});
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128;
|
||||||
|
|
||||||
|
const auto xType = input.dataType();
|
||||||
|
const auto yType = paddings.dataType();
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue});
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void mirrorPadLinearKernel(void const* vx, Nd4jLong* xShape, void* vz, Nd4jLong* zShape, Nd4jLong leftSide, Nd4jLong leftSideCorrected, Nd4jLong xLen, Nd4jLong len, Nd4jLong zLen) {
|
||||||
|
|
||||||
|
__shared__ T const* x;
|
||||||
|
__shared__ T* z;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
x = reinterpret_cast<T const*>(vx);
|
||||||
|
z = reinterpret_cast<T*>(vz);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for(int i = start; i < zLen; i+= step) {
|
||||||
|
auto zIndex = shape::getIndexOffset(i, zShape, zLen);
|
||||||
|
auto xIndex = shape::getIndexOffset(len - i, xShape, xLen);
|
||||||
|
|
||||||
|
if (i < leftSide) // left side
|
||||||
|
xIndex = shape::getIndexOffset(leftSideCorrected - i, xShape, xLen);
|
||||||
|
|
||||||
|
else if(i >= leftSide && i < leftSide + xLen) // middle
|
||||||
|
xIndex = shape::getIndexOffset(i - leftSide, xShape, xLen);
|
||||||
|
|
||||||
|
// else // right side
|
||||||
|
// z[i] = x[len - i];
|
||||||
|
z[zIndex] = x[xIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F, typename I>
|
||||||
|
static __global__ void mirrorPadKernel(void const* vx, Nd4jLong* xShape, void* vz, Nd4jLong* zShape, Nd4jLong outLen, void const* paddings, Nd4jLong* paddingShape, int reflBorder) {
|
||||||
|
|
||||||
|
__shared__ F const* x;
|
||||||
|
__shared__ I const* pads;
|
||||||
|
__shared__ F* z;
|
||||||
|
__shared__ Nd4jLong zRank, rank;
|
||||||
|
__shared__ Nd4jLong* xShapeOf, *xStrideOf, *padsShapeOf, *padsStrideOf;
|
||||||
|
__shared__ Nd4jLong* zShapeOf, *zStrideOf;
|
||||||
|
__shared__ Nd4jLong* xIdx;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
xIdx = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
rank = shape::rank(xShape);
|
||||||
|
|
||||||
|
x = reinterpret_cast<F const*>(vx);//
|
||||||
|
pads = reinterpret_cast<I const*>(paddings);
|
||||||
|
z = reinterpret_cast<F*>(vz);
|
||||||
|
xShapeOf = shape::shapeOf(xShape);
|
||||||
|
xStrideOf = shape::stride(xShape);
|
||||||
|
zShapeOf = shape::shapeOf(zShape);
|
||||||
|
zRank = shape::rank(zShape);
|
||||||
|
zStrideOf = shape::stride(zShape);
|
||||||
|
padsShapeOf = shape::shapeOf(paddingShape);
|
||||||
|
padsStrideOf = shape::stride(paddingShape);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for(Nd4jLong i = start; i < outLen; i+= step) {
|
||||||
|
auto xzCoord = xIdx + threadIdx.x * rank;
|
||||||
|
//auto zxCoord = xIdx + (threadIdx.x + threadIdx.x % 2 + 1) * rank;
|
||||||
|
|
||||||
|
shape::index2coords(rank, zShapeOf, i, xzCoord);
|
||||||
|
auto outOffset = shape::getOffset(0, zShapeOf, zStrideOf, xzCoord, rank);
|
||||||
|
// auto intStep = blockDim.y * gridDim.y;
|
||||||
|
for(int j = 0; j < rank; j++) {
|
||||||
|
|
||||||
|
const Nd4jLong inLen = shape::sizeAt(xShape, j);
|
||||||
|
Nd4jLong coords[2] = {j, 0};
|
||||||
|
auto padOffset = shape::getOffset(0, padsShapeOf, padsStrideOf, coords, 2); // padding already has rank 2
|
||||||
|
const auto leftSide = pads[padOffset];
|
||||||
|
const auto leftSideCorrected = leftSide - reflBorder;
|
||||||
|
const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder;
|
||||||
|
|
||||||
|
if(xzCoord[j] < leftSide) // left side
|
||||||
|
xzCoord[j] = leftSideCorrected - xzCoord[j];
|
||||||
|
|
||||||
|
else if(xzCoord[j] >= leftSide && xzCoord[j] < leftSide + inLen) // middle
|
||||||
|
xzCoord[j] = xzCoord[j] - leftSide;
|
||||||
|
|
||||||
|
else if (len > xzCoord[j]) // right side
|
||||||
|
xzCoord[j] = len - xzCoord[j];
|
||||||
|
else
|
||||||
|
xzCoord[j] = xzCoord[j] - len;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inOffset = shape::getOffset(0, xShapeOf, xStrideOf, xzCoord, rank);
|
||||||
|
z[outOffset] = x[inOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename F, typename I>
|
||||||
|
static void mirrorPad_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) {
|
||||||
|
// mode: 0 - REFLECT, else - SYMMETRIC
|
||||||
|
const int reflBorder = (bool)mode ? 1 : 0;
|
||||||
|
const int rank = input.rankOf();
|
||||||
|
const Nd4jLong outLen = output.lengthOf();
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input, &paddings});
|
||||||
|
|
||||||
|
if(rank <= 1) {
|
||||||
|
|
||||||
|
const Nd4jLong inLen = input.lengthOf();
|
||||||
|
const auto leftSide = paddings.e<Nd4jLong>(0);
|
||||||
|
const auto leftSideCorrected = leftSide - reflBorder;
|
||||||
|
const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder;
|
||||||
|
|
||||||
|
mirrorPadLinearKernel<F><<<256, 512, 256, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftSide, leftSideCorrected, inLen, len, outLen);
|
||||||
|
nd4j::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadLinearKernel(...) failed");
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
mirrorPadKernel<F, I><<<256, 256, 8192, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), outLen, paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), reflBorder);
|
||||||
|
nd4j::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadKernel(...) failed");
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input, &paddings});
|
||||||
|
}
|
||||||
|
|
||||||
|
void mirrorPad(nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void mirrorPad_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <array/ResultSet.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename X, typename Y>
|
||||||
|
static _CUDA_G void scatterSimpleKernel(void *vx, Nd4jLong *xTadShape, Nd4jLong *xTadOffsets, Nd4jLong xLength, Nd4jLong numTads, void *vi, Nd4jLong *iShapeInfo, Nd4jLong iLength, void *vu, Nd4jLong *uShapeInfo, Nd4jLong uLength) {
|
||||||
|
auto u = reinterpret_cast<X*>(vu);
|
||||||
|
auto indices = reinterpret_cast<Y*>(vi);
|
||||||
|
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
for (int i = tid; i < iLength; i += blockDim.x * gridDim.x) {
|
||||||
|
auto x = reinterpret_cast<X*>(vx) + xTadOffsets[i];
|
||||||
|
auto idx = indices[shape::getIndexOffset(i, iShapeInfo, iLength)];
|
||||||
|
|
||||||
|
x[shape::getIndexOffset(idx, xTadShape, xLength)] = u[shape::getIndexOffset(i, uShapeInfo, uLength)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename X, typename Y>
|
||||||
|
void scatterSimple_(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
||||||
|
|
||||||
|
auto dims = ShapeUtils::evalDimsToExclude(input.rankOf(), dimensions);
|
||||||
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dims);
|
||||||
|
|
||||||
|
auto xLength = shape::length(packX.primaryShapeInfo());
|
||||||
|
auto iLength = indices.lengthOf();
|
||||||
|
auto uLength = updates.lengthOf();
|
||||||
|
|
||||||
|
scatterSimpleKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), iLength, updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), uLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void scatterSimple(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
||||||
|
auto xType = input.dataType();
|
||||||
|
auto yType = indices.dataType();
|
||||||
|
|
||||||
|
if (opId != 6)
|
||||||
|
throw std::runtime_error("scatterSimple: only copy op is supported");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&input}, {&updates, &indices});
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&input}, {&updates, &indices});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,133 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <array/ResultSet.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd,
|
||||||
|
void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets,
|
||||||
|
void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
|
||||||
|
const int* indexes) {
|
||||||
|
|
||||||
|
__shared__ T *x, *y;
|
||||||
|
__shared__ Nd4jLong arrLenX, arrLenY;
|
||||||
|
|
||||||
|
for (int e = 0; e < numOfInd; e++ ) {
|
||||||
|
|
||||||
|
const auto xIndex = indexes[e];
|
||||||
|
const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x;
|
||||||
|
|
||||||
|
if (!isOwner)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
x = reinterpret_cast<T*>(vx) + xOffsets[xIndex];
|
||||||
|
y = reinterpret_cast<T*>(vy) + yOffsets[e];
|
||||||
|
arrLenX = shape::length(xShapeInfo);
|
||||||
|
arrLenY = shape::length(yShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (arrLenX != arrLenY)
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) {
|
||||||
|
|
||||||
|
const auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLenX);
|
||||||
|
const auto yOffset = shape::getIndexOffset(i, yShapeInfo, arrLenY);
|
||||||
|
|
||||||
|
switch (opCode) {
|
||||||
|
case 0:
|
||||||
|
x[xOffset] += y[yOffset];
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
x[xOffset] -= y[yOffset];
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
x[xOffset] *= y[yOffset];
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
x[xOffset] /= y[yOffset];
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
x[xOffset] = y[yOffset] - x[xOffset];
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
x[xOffset] = y[yOffset] / x[xOffset];
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
x[xOffset] = y[yOffset];
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) {
|
||||||
|
|
||||||
|
scatterUpdateCuda<T><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) {
|
||||||
|
|
||||||
|
const int opCode = (*intArgs)[0];
|
||||||
|
const int numOfDims = (*intArgs)[1];
|
||||||
|
const int numOfInd = (*intArgs)[2 + numOfDims];
|
||||||
|
|
||||||
|
std::vector<int> tadDimensions(numOfDims);
|
||||||
|
for (int e = 2; e < 2 + numOfDims; e++)
|
||||||
|
tadDimensions[e-2] = (*intArgs)[e];
|
||||||
|
|
||||||
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), tadDimensions);
|
||||||
|
auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), tadDimensions);
|
||||||
|
|
||||||
|
NDArray indices(const_cast<int*>(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, nd4j::DataType::INT32, context);
|
||||||
|
|
||||||
|
PointersManager manager(context, "scatterUpdate");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices});
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast<int*>(indices.getSpecialBuffer())), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&input}, {&input, &updates, &indices});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -33,163 +33,6 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
__global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
|
|
||||||
|
|
||||||
__shared__ int arrIdx, blocksPerArr;
|
|
||||||
__shared__ T *x, *z;
|
|
||||||
__shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen, arrLenPerBlock, start, end;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
|
|
||||||
blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
|
|
||||||
arrIdx = blockIdx.x / blocksPerArr;
|
|
||||||
|
|
||||||
x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[arrIdx]);
|
|
||||||
z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[arrIdx]);
|
|
||||||
xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[arrIdx];
|
|
||||||
zShapeInfo = reinterpret_cast<Nd4jLong**>(pzShapeInfo)[arrIdx];
|
|
||||||
arrLen = shape::length(xShapeInfo);
|
|
||||||
|
|
||||||
arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr; // ceil
|
|
||||||
|
|
||||||
start = (blockIdx.x % blocksPerArr) * arrLenPerBlock;
|
|
||||||
end = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock);
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x)
|
|
||||||
z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)];
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
__host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
|
|
||||||
|
|
||||||
concatCuda<T><<<512, 256, 1024, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
// x - input, y - paddings, z - output
|
|
||||||
template<typename X, typename Y>
|
|
||||||
__global__ static void padCuda(const int mode,
|
|
||||||
const void *vx, const Nd4jLong *xShapeInfo,
|
|
||||||
const void *vy, const Nd4jLong *yShapeInfo,
|
|
||||||
void *vz, const Nd4jLong *zShapeInfo,
|
|
||||||
const void *vPadVal) {
|
|
||||||
|
|
||||||
const X padVal = *reinterpret_cast<const X*>(vPadVal);
|
|
||||||
|
|
||||||
const auto x = reinterpret_cast<const X*>(vx);
|
|
||||||
const auto y = reinterpret_cast<const Y*>(vy);
|
|
||||||
auto z = reinterpret_cast<X*>(vz);
|
|
||||||
|
|
||||||
__shared__ int rank, rankMinusOne;
|
|
||||||
__shared__ Nd4jLong zLen, yLen, totalThreads, *coords, *xShape, *zShape, *xStride, *zStride, shift1, shift2, yStride0;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
|
|
||||||
extern __shared__ unsigned char shmem[];
|
|
||||||
coords = reinterpret_cast<Nd4jLong*>(shmem);
|
|
||||||
zLen = shape::length(zShapeInfo);
|
|
||||||
xShape = shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo));
|
|
||||||
zShape = shape::shapeOf(const_cast<Nd4jLong*>(zShapeInfo));
|
|
||||||
xStride = shape::stride(const_cast<Nd4jLong*>(xShapeInfo));
|
|
||||||
zStride = shape::stride(const_cast<Nd4jLong*>(zShapeInfo));
|
|
||||||
yStride0 = shape::stride(const_cast<Nd4jLong*>(yShapeInfo))[0];
|
|
||||||
rank = shape::rank(xShapeInfo);
|
|
||||||
zLen = shape::length(zShapeInfo);
|
|
||||||
yLen = 2 * rank;
|
|
||||||
rankMinusOne = rank - 1;
|
|
||||||
totalThreads = gridDim.x * blockDim.x;
|
|
||||||
shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC
|
|
||||||
shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
auto xzCoord = coords + threadIdx.x * rank; // we use xzCoord storage both for x and z arrays
|
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
|
|
||||||
if(mode == 0) { // CONSTANT case
|
|
||||||
|
|
||||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
|
||||||
|
|
||||||
shape::index2coords(rank, zShape, i, zLen, xzCoord);
|
|
||||||
const auto zOffset = shape::getOffset(0, zShape, zStride, xzCoord, rank);
|
|
||||||
|
|
||||||
bool within = true;
|
|
||||||
for(int j = rankMinusOne; j >= 0; --j) {
|
|
||||||
if(xShape[j] == zShape[j]) continue;
|
|
||||||
const auto left = y[shape::getIndexOffset(yStride0 * j, yShapeInfo, yLen)];
|
|
||||||
if(xzCoord[j] < left || xzCoord[j] >= left + xShape[j]) {within = false; break;}
|
|
||||||
else {xzCoord[j] = xzCoord[j] - left;}
|
|
||||||
}
|
|
||||||
|
|
||||||
if(within)
|
|
||||||
z[zOffset] = x[shape::getOffset(0, xShape, xStride, xzCoord, rank)];
|
|
||||||
else
|
|
||||||
z[zOffset] = padVal;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else { // REFLECT and SYMMETRIC cases
|
|
||||||
|
|
||||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
|
||||||
|
|
||||||
shape::index2coords(rank, zShape, i, zLen, xzCoord);
|
|
||||||
const auto zOffset = shape::getOffset(0, zShape, zStride, xzCoord, rank);
|
|
||||||
|
|
||||||
for(int j = rankMinusOne; j >= 0; --j) {
|
|
||||||
|
|
||||||
if(xShape[j] == zShape[j]) continue;
|
|
||||||
xzCoord[j] = xzCoord[j] - y[shape::getIndexOffset(yStride0 * j, yShapeInfo, yLen)]; // are ready to fill middle (within input dimension range)
|
|
||||||
if(xzCoord[j] < 0) xzCoord[j] = -xzCoord[j] - shift1; // means fill from left
|
|
||||||
else if(xzCoord[j] >= xShape[j]) xzCoord[j] = 2 * xShape[j] - xzCoord[j] - shift2; // means fill from right
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(0, xShape, xStride, xzCoord, rank);
|
|
||||||
z[zOffset] = x[xOffset];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
template<typename X, typename Y>
|
|
||||||
static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
|
||||||
const int mode,
|
|
||||||
const void *vx, const Nd4jLong *xShapeInfo,
|
|
||||||
const void *vy, const Nd4jLong *yShapeInfo,
|
|
||||||
void *vz, const Nd4jLong *zShapeInfo,
|
|
||||||
const void* padVal) {
|
|
||||||
|
|
||||||
padCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal);
|
|
||||||
}
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
|
|
||||||
|
|
||||||
PointersManager manager(context, "pad");
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue});
|
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
|
||||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
|
||||||
const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128;
|
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
|
||||||
const auto yType = paddings.dataType();
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue});
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__global__ static void invertPermutationCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
|
__global__ static void invertPermutationCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
|
||||||
|
@ -458,214 +301,6 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd,
|
|
||||||
void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets,
|
|
||||||
void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
|
|
||||||
const int* indexes) {
|
|
||||||
|
|
||||||
__shared__ T *x, *y;
|
|
||||||
__shared__ Nd4jLong arrLenX, arrLenY;
|
|
||||||
|
|
||||||
for (int e = 0; e < numOfInd; e++ ) {
|
|
||||||
|
|
||||||
const auto xIndex = indexes[e];
|
|
||||||
const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x;
|
|
||||||
|
|
||||||
if (!isOwner)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
x = reinterpret_cast<T*>(vx) + xOffsets[xIndex];
|
|
||||||
y = reinterpret_cast<T*>(vy) + yOffsets[e];
|
|
||||||
arrLenX = shape::length(xShapeInfo);
|
|
||||||
arrLenY = shape::length(yShapeInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
if (arrLenX != arrLenY)
|
|
||||||
return;
|
|
||||||
|
|
||||||
for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) {
|
|
||||||
|
|
||||||
const auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLenX);
|
|
||||||
const auto yOffset = shape::getIndexOffset(i, yShapeInfo, arrLenY);
|
|
||||||
|
|
||||||
switch (opCode) {
|
|
||||||
case 0:
|
|
||||||
x[xOffset] += y[yOffset];
|
|
||||||
break;
|
|
||||||
case 1:
|
|
||||||
x[xOffset] -= y[yOffset];
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
x[xOffset] *= y[yOffset];
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
x[xOffset] /= y[yOffset];
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
x[xOffset] = y[yOffset] - x[xOffset];
|
|
||||||
break;
|
|
||||||
case 5:
|
|
||||||
x[xOffset] = y[yOffset] / x[xOffset];
|
|
||||||
break;
|
|
||||||
case 6:
|
|
||||||
x[xOffset] = y[yOffset];
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) {
|
|
||||||
|
|
||||||
scatterUpdateCuda<T><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) {
|
|
||||||
|
|
||||||
const int opCode = (*intArgs)[0];
|
|
||||||
const int numOfDims = (*intArgs)[1];
|
|
||||||
const int numOfInd = (*intArgs)[2 + numOfDims];
|
|
||||||
|
|
||||||
std::vector<int> tadDimensions(numOfDims);
|
|
||||||
for (int e = 2; e < 2 + numOfDims; e++)
|
|
||||||
tadDimensions[e-2] = (*intArgs)[e];
|
|
||||||
|
|
||||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), tadDimensions);
|
|
||||||
auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), tadDimensions);
|
|
||||||
|
|
||||||
NDArray indices(const_cast<int*>(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, nd4j::DataType::INT32, context);
|
|
||||||
|
|
||||||
PointersManager manager(context, "scatterUpdate");
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices});
|
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast<int*>(indices.getSpecialBuffer())), LIBND4J_TYPES);
|
|
||||||
NDArray::registerSpecialUse({&input}, {&input, &updates, &indices});
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
// x - input, y - indices, z - output
|
|
||||||
template<typename X, typename Y>
|
|
||||||
__global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
|
||||||
const void *vy, const Nd4jLong *yShapeInfo,
|
|
||||||
void *vz, const Nd4jLong *zShapeInfo) {
|
|
||||||
|
|
||||||
const auto x = reinterpret_cast<const X*>(vx);
|
|
||||||
const auto y = reinterpret_cast<const Y*>(vy);
|
|
||||||
auto z = reinterpret_cast<X*>(vz);
|
|
||||||
|
|
||||||
__shared__ int xRank, yRank, zRank, maxRank, yLastDim;
|
|
||||||
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
|
|
||||||
extern __shared__ unsigned char shmem[];
|
|
||||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
|
||||||
|
|
||||||
xRank = shape::rank(xShapeInfo);
|
|
||||||
yRank = shape::rank(yShapeInfo);
|
|
||||||
zRank = shape::rank(zShapeInfo);
|
|
||||||
maxRank = nd4j::math::nd4j_max<int>(yRank, nd4j::math::nd4j_max<int>(xRank, zRank));
|
|
||||||
|
|
||||||
zLen = shape::length(zShapeInfo);
|
|
||||||
yLastDim = yShapeInfo[yRank];
|
|
||||||
|
|
||||||
totalThreads = gridDim.x * blockDim.x;
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
auto coord = sharedMem + threadIdx.x * maxRank;
|
|
||||||
|
|
||||||
Nd4jLong *zCoordStart, *xCoordStart;
|
|
||||||
|
|
||||||
if(yLastDim == xRank) {
|
|
||||||
zCoordStart = coord;
|
|
||||||
xCoordStart = coord;
|
|
||||||
}
|
|
||||||
if(zRank >= xRank) {
|
|
||||||
zCoordStart = coord;
|
|
||||||
xCoordStart = coord + zRank - xRank;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
zCoordStart = coord + xRank - zRank;
|
|
||||||
xCoordStart = coord;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
|
|
||||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
|
||||||
|
|
||||||
shape::index2coords(zRank, zShapeInfo + 1, i, zLen, zCoordStart);
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + zRank + 1, zCoordStart, zRank);
|
|
||||||
|
|
||||||
// last y coordinate
|
|
||||||
int coordToRestore;
|
|
||||||
if(yLastDim != xRank)
|
|
||||||
coordToRestore = static_cast<int>(zCoordStart[yRank - 1]);
|
|
||||||
|
|
||||||
zCoordStart[yRank - 1] = 0; // last y coordinate
|
|
||||||
const auto yOffset = shape::getOffset(0, yShapeInfo + 1, yShapeInfo + yRank + 1, zCoordStart, yRank);
|
|
||||||
|
|
||||||
//restore z coordinate
|
|
||||||
if(yLastDim != xRank)
|
|
||||||
zCoordStart[yRank - 1] = coordToRestore;
|
|
||||||
|
|
||||||
// construct coordinates for x
|
|
||||||
for(uint j = 0; j < yLastDim; ++j)
|
|
||||||
xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]]; // last stride
|
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, xCoordStart, xRank);
|
|
||||||
|
|
||||||
z[zOffset] = x[xOffset];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
template<typename X, typename Y>
|
|
||||||
static void gatherNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
|
||||||
const void *vx, const Nd4jLong *xShapeInfo,
|
|
||||||
const void *vy, const Nd4jLong *yShapeInfo,
|
|
||||||
void *vz, const Nd4jLong *zShapeInfo) {
|
|
||||||
|
|
||||||
gatherNDCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
|
|
||||||
}
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void gatherNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
|
|
||||||
|
|
||||||
const int maxRank = nd4j::math::nd4j_max<int>(indices.rankOf(), nd4j::math::nd4j_max<int>(input.rankOf(), output.rankOf()));
|
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS;
|
|
||||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
|
||||||
const int sharedMem = 8 * threadsPerBlock * maxRank + 128;
|
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
|
||||||
const auto yType = indices.dataType();
|
|
||||||
|
|
||||||
PointersManager manager(context, "gatherND");
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&output}, {&input, &indices});
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
NDArray::registerSpecialUse({&output}, {&input, &indices});
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// x - input, y - gradO, z - gradI
|
// x - input, y - gradO, z - gradI
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
|
@ -929,43 +564,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void swapShuffleKernel(T* input, Nd4jLong* shape, Nd4jLong firstDim, Nd4jLong len, nd4j::graph::RandomGenerator* rng) {
|
static __global__ void swapShuffleKernel(T* input, Nd4jLong* shape, Nd4jLong firstDim, Nd4jLong len, nd4j::graph::RandomGenerator* rng) {
|
||||||
auto tid = blockIdx.x * blockDim.x;
|
auto tid = blockIdx.x * blockDim.x;
|
||||||
|
@ -1093,207 +691,9 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void eye(nd4j::LaunchContext * context, NDArray& output) {
|
void eye(nd4j::LaunchContext * context, NDArray& output) {
|
||||||
|
|
||||||
output.setIdentity();
|
output.setIdentity();
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T, typename Z>
|
|
||||||
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
|
||||||
auto output = reinterpret_cast<Z*>(voutput);
|
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
|
||||||
const auto step = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
|
||||||
T mVal = -DataTypeUtils::max<T>();
|
|
||||||
Z mIdx(0);
|
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
|
||||||
auto val = x[shape::getIndexOffset(e, xShape, length)];;
|
|
||||||
if (mVal < val)
|
|
||||||
mIdx = static_cast<Z>(e);
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
output[shape::getIndexOffset(e, outputShape, length)] = mIdx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename Z>
|
|
||||||
static void mergeMaxIndex_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
|
||||||
}
|
|
||||||
|
|
||||||
PointersManager manager(context, "mergeMaxIndex");
|
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
|
||||||
auto length = output.lengthOf();
|
|
||||||
|
|
||||||
global_mergeMaxIndex_<T,Z><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
|
|
||||||
void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void mergeMaxIndex_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
|
||||||
const auto step = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
|
||||||
T mVal = -DataTypeUtils::max<T>();
|
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
|
||||||
auto val = x[shape::getIndexOffset(e, xShape, length)];;
|
|
||||||
if (mVal < val)
|
|
||||||
mVal = val;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
output[shape::getIndexOffset(e, outputShape, length)] = mVal;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
static void mergeMax_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
|
||||||
}
|
|
||||||
|
|
||||||
PointersManager manager(context, "mergeMax");
|
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
|
||||||
auto length = output.lengthOf();
|
|
||||||
|
|
||||||
global_mergeMax_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void mergeMax_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
void mergeMax(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
|
||||||
const auto step = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
|
||||||
T sum(0.0f);
|
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
|
||||||
|
|
||||||
sum += x[shape::getIndexOffset(e, xShape, length)];
|
|
||||||
}
|
|
||||||
|
|
||||||
output[shape::getIndexOffset(e, outputShape, length)] = sum / numArrays;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
static void mergeAvg_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
|
||||||
}
|
|
||||||
|
|
||||||
PointersManager manager(context, "mergeAvg");
|
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
|
||||||
auto length = output.lengthOf();
|
|
||||||
|
|
||||||
global_mergeAvg_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void mergeAvg_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
void mergeAvg(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), LIBND4J_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
|
||||||
const auto step = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
|
||||||
T sum(0.0f);
|
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
|
||||||
|
|
||||||
sum += x[shape::getIndexOffset(e, xShape, length)];
|
|
||||||
}
|
|
||||||
|
|
||||||
output[shape::getIndexOffset(e, outputShape, length)] = sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
static void mergeAdd_(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
|
||||||
}
|
|
||||||
|
|
||||||
PointersManager manager(context, "mergeAdd");
|
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
|
||||||
auto length = output.lengthOf();
|
|
||||||
|
|
||||||
global_mergeAdd_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
void mergeAdd(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), LIBND4J_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -1546,232 +946,6 @@ void eye(nd4j::LaunchContext * context, NDArray& output) {
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void clipByValue_, (nd4j::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void clipByValue_, (nd4j::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
static __global__ void mirrorPadLinearKernel(void const* vx, Nd4jLong* xShape, void* vz, Nd4jLong* zShape, Nd4jLong leftSide, Nd4jLong leftSideCorrected, Nd4jLong xLen, Nd4jLong len, Nd4jLong zLen) {
|
|
||||||
|
|
||||||
__shared__ T const* x;
|
|
||||||
__shared__ T* z;
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
x = reinterpret_cast<T const*>(vx);
|
|
||||||
z = reinterpret_cast<T*>(vz);
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
auto step = blockDim.x * gridDim.x;
|
|
||||||
|
|
||||||
for(int i = start; i < zLen; i+= step) {
|
|
||||||
auto zIndex = shape::getIndexOffset(i, zShape, zLen);
|
|
||||||
auto xIndex = shape::getIndexOffset(len - i, xShape, xLen);
|
|
||||||
|
|
||||||
if (i < leftSide) // left side
|
|
||||||
xIndex = shape::getIndexOffset(leftSideCorrected - i, xShape, xLen);
|
|
||||||
|
|
||||||
else if(i >= leftSide && i < leftSide + xLen) // middle
|
|
||||||
xIndex = shape::getIndexOffset(i - leftSide, xShape, xLen);
|
|
||||||
|
|
||||||
// else // right side
|
|
||||||
// z[i] = x[len - i];
|
|
||||||
z[zIndex] = x[xIndex];
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename F, typename I>
|
|
||||||
static __global__ void mirrorPadKernel(void const* vx, Nd4jLong* xShape, void* vz, Nd4jLong* zShape, Nd4jLong outLen, void const* paddings, Nd4jLong* paddingShape, int reflBorder) {
|
|
||||||
|
|
||||||
__shared__ F const* x;
|
|
||||||
__shared__ I const* pads;
|
|
||||||
__shared__ F* z;
|
|
||||||
__shared__ Nd4jLong zRank, rank;
|
|
||||||
__shared__ Nd4jLong* xShapeOf, *xStrideOf, *padsShapeOf, *padsStrideOf;
|
|
||||||
__shared__ Nd4jLong* zShapeOf, *zStrideOf;
|
|
||||||
__shared__ Nd4jLong* xIdx;
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
extern __shared__ unsigned char shmem[];
|
|
||||||
xIdx = reinterpret_cast<Nd4jLong*>(shmem);
|
|
||||||
rank = shape::rank(xShape);
|
|
||||||
|
|
||||||
x = reinterpret_cast<F const*>(vx);//
|
|
||||||
pads = reinterpret_cast<I const*>(paddings);
|
|
||||||
z = reinterpret_cast<F*>(vz);
|
|
||||||
xShapeOf = shape::shapeOf(xShape);
|
|
||||||
xStrideOf = shape::stride(xShape);
|
|
||||||
zShapeOf = shape::shapeOf(zShape);
|
|
||||||
zRank = shape::rank(zShape);
|
|
||||||
zStrideOf = shape::stride(zShape);
|
|
||||||
padsShapeOf = shape::shapeOf(paddingShape);
|
|
||||||
padsStrideOf = shape::stride(paddingShape);
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
auto step = blockDim.x * gridDim.x;
|
|
||||||
|
|
||||||
for(Nd4jLong i = start; i < outLen; i+= step) {
|
|
||||||
auto xzCoord = xIdx + threadIdx.x * rank;
|
|
||||||
//auto zxCoord = xIdx + (threadIdx.x + threadIdx.x % 2 + 1) * rank;
|
|
||||||
|
|
||||||
shape::index2coords(rank, zShapeOf, i, xzCoord);
|
|
||||||
auto outOffset = shape::getOffset(0, zShapeOf, zStrideOf, xzCoord, rank);
|
|
||||||
// auto intStep = blockDim.y * gridDim.y;
|
|
||||||
for(int j = 0; j < rank; j++) {
|
|
||||||
|
|
||||||
const Nd4jLong inLen = shape::sizeAt(xShape, j);
|
|
||||||
Nd4jLong coords[2] = {j, 0};
|
|
||||||
auto padOffset = shape::getOffset(0, padsShapeOf, padsStrideOf, coords, 2); // padding already has rank 2
|
|
||||||
const auto leftSide = pads[padOffset];
|
|
||||||
const auto leftSideCorrected = leftSide - reflBorder;
|
|
||||||
const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder;
|
|
||||||
|
|
||||||
if(xzCoord[j] < leftSide) // left side
|
|
||||||
xzCoord[j] = leftSideCorrected - xzCoord[j];
|
|
||||||
|
|
||||||
else if(xzCoord[j] >= leftSide && xzCoord[j] < leftSide + inLen) // middle
|
|
||||||
xzCoord[j] = xzCoord[j] - leftSide;
|
|
||||||
|
|
||||||
else if (len > xzCoord[j]) // right side
|
|
||||||
xzCoord[j] = len - xzCoord[j];
|
|
||||||
else
|
|
||||||
xzCoord[j] = xzCoord[j] - len;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto inOffset = shape::getOffset(0, xShapeOf, xStrideOf, xzCoord, rank);
|
|
||||||
z[outOffset] = x[inOffset];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename F, typename I>
|
|
||||||
static void mirrorPad_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) {
|
|
||||||
// mode: 0 - REFLECT, else - SYMMETRIC
|
|
||||||
const int reflBorder = (bool)mode ? 1 : 0;
|
|
||||||
const int rank = input.rankOf();
|
|
||||||
const Nd4jLong outLen = output.lengthOf();
|
|
||||||
auto stream = context->getCudaStream();
|
|
||||||
NDArray::prepareSpecialUse({&output}, {&input, &paddings});
|
|
||||||
|
|
||||||
if(rank <= 1) {
|
|
||||||
|
|
||||||
const Nd4jLong inLen = input.lengthOf();
|
|
||||||
const auto leftSide = paddings.e<Nd4jLong>(0);
|
|
||||||
const auto leftSideCorrected = leftSide - reflBorder;
|
|
||||||
const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder;
|
|
||||||
|
|
||||||
mirrorPadLinearKernel<F><<<256, 512, 256, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftSide, leftSideCorrected, inLen, len, outLen);
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadLinearKernel(...) failed");
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
mirrorPadKernel<F, I><<<256, 256, 8192, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), outLen, paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), reflBorder);
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadKernel(...) failed");
|
|
||||||
}
|
|
||||||
NDArray::registerSpecialUse({&output}, {&input, &paddings});
|
|
||||||
}
|
|
||||||
|
|
||||||
void mirrorPad(nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) {
|
|
||||||
BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void mirrorPad_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
|
||||||
|
|
||||||
const int numOfArrs = inArrs.size();
|
|
||||||
for(int i = 0; i < numOfArrs; ++i)
|
|
||||||
if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
|
|
||||||
|
|
||||||
const int rank = inArrs[0]->rankOf();
|
|
||||||
const int rank2 = 2*rank;
|
|
||||||
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
|
|
||||||
|
|
||||||
// take into account indices for first array
|
|
||||||
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
|
|
||||||
|
|
||||||
// loop through the rest of input arrays
|
|
||||||
for(int i = 1; i < numOfArrs; ++i) {
|
|
||||||
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
|
|
||||||
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<NDArray*> outSubArrs(numOfArrs);
|
|
||||||
for(int i = 0; i < numOfArrs; ++i)
|
|
||||||
outSubArrs[i] = new NDArray(output(indices[i], true));
|
|
||||||
|
|
||||||
// prepare arrays of pointers on buffers and shapes
|
|
||||||
std::vector<void*> hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
|
|
||||||
std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
|
|
||||||
for(int i = 0; i < numOfArrs; ++i) {
|
|
||||||
hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer();
|
|
||||||
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
|
|
||||||
hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
|
|
||||||
hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo();
|
|
||||||
}
|
|
||||||
|
|
||||||
// allocate and copy all buffers and shapes arrays to global memory
|
|
||||||
PointersManager manager(context, "helpers::concat");
|
|
||||||
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
|
|
||||||
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
|
|
||||||
void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
|
|
||||||
void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
manager.synchronize();
|
|
||||||
|
|
||||||
for(int i = 0; i < numOfArrs; ++i)
|
|
||||||
delete outSubArrs[i];
|
|
||||||
|
|
||||||
for(int i = 0; i < numOfArrs; ++i)
|
|
||||||
inArrs[i]->tickReadHost();
|
|
||||||
|
|
||||||
output.tickWriteDevice();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
static _CUDA_G void scatterSimpleKernel(void *vx, Nd4jLong *xTadShape, Nd4jLong *xTadOffsets, Nd4jLong xLength, Nd4jLong numTads, void *vi, Nd4jLong *iShapeInfo, Nd4jLong iLength, void *vu, Nd4jLong *uShapeInfo, Nd4jLong uLength) {
|
|
||||||
auto u = reinterpret_cast<X*>(vu);
|
|
||||||
auto indices = reinterpret_cast<Y*>(vi);
|
|
||||||
|
|
||||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
for (int i = tid; i < iLength; i += blockDim.x * gridDim.x) {
|
|
||||||
auto x = reinterpret_cast<X*>(vx) + xTadOffsets[i];
|
|
||||||
auto idx = indices[shape::getIndexOffset(i, iShapeInfo, iLength)];
|
|
||||||
|
|
||||||
x[shape::getIndexOffset(idx, xTadShape, xLength)] = u[shape::getIndexOffset(i, uShapeInfo, uLength)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void scatterSimple_(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
|
||||||
|
|
||||||
auto dims = ShapeUtils::evalDimsToExclude(input.rankOf(), dimensions);
|
|
||||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dims);
|
|
||||||
|
|
||||||
auto xLength = shape::length(packX.primaryShapeInfo());
|
|
||||||
auto iLength = indices.lengthOf();
|
|
||||||
auto uLength = updates.lengthOf();
|
|
||||||
|
|
||||||
scatterSimpleKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), iLength, updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), uLength);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void scatterSimple(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
|
||||||
auto xType = input.dataType();
|
|
||||||
auto yType = indices.dataType();
|
|
||||||
|
|
||||||
if (opId != 6)
|
|
||||||
throw std::runtime_error("scatterSimple: only copy op is supported");
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&input}, {&updates, &indices});
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&input}, {&updates, &indices});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,12 +29,16 @@ if (CUDA_BLAS)
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
message("CUDA on Windows: enabling /EHsc")
|
message("CUDA on Windows: enabling /EHsc")
|
||||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS")
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS /w")
|
||||||
SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc")
|
SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if ("${COMPUTE}" STREQUAL "all")
|
||||||
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
|
||||||
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||||
endif()
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
# -fsanitize=address
|
# -fsanitize=address
|
||||||
# -fsanitize=leak
|
# -fsanitize=leak
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -229,22 +229,22 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
||||||
auto a = NDArrayFactory::create<float>('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f});
|
auto a = NDArrayFactory::create<double>('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f});
|
||||||
auto b = NDArrayFactory::create<float>('c', {1, 3});
|
auto b = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto c = NDArrayFactory::create<float>('c', {1, 3});
|
auto c = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto d = NDArrayFactory::create<float>('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976});
|
auto d = NDArrayFactory::create<double>('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976});
|
||||||
auto e = NDArrayFactory::create<float>('c', {3});
|
auto e = NDArrayFactory::create<double>('c', {3});
|
||||||
auto f = NDArrayFactory::create<float>('c', {3});
|
auto f = NDArrayFactory::create<double>('c', {3});
|
||||||
auto g = NDArrayFactory::create<float>('c', {3});
|
auto g = NDArrayFactory::create<double>('c', {3});
|
||||||
auto h = NDArrayFactory::create<float>('c', {12});
|
auto h = NDArrayFactory::create<double>('c', {12});
|
||||||
|
|
||||||
auto z0 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z0 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto z1 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z1 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto z2 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z2 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto z3 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z3 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto z4 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z4 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto z5 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z5 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto z6 = NDArrayFactory::create<float>('c', {1, 3});
|
auto z6 = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
|
|
||||||
nd4j::ops::lstmBlockCell op;
|
nd4j::ops::lstmBlockCell op;
|
||||||
auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {});
|
auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {});
|
||||||
|
|
|
@ -1049,7 +1049,8 @@ TEST_F(NativeOpsTests, ConcatTest_1) {
|
||||||
//y.assign(2.);
|
//y.assign(2.);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
z.syncToDevice();
|
z.syncToDevice();
|
||||||
auto dimension = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
int d = 0;
|
||||||
|
auto dimension = NDArrayFactory::create<int>('c', {1}, {d});
|
||||||
auto dimensions = reinterpret_cast<int*>(dimension.buffer());
|
auto dimensions = reinterpret_cast<int*>(dimension.buffer());
|
||||||
//auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
|
//auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
|
@ -1087,7 +1088,8 @@ TEST_F(NativeOpsTests, ConcatTest_2) {
|
||||||
//y.assign(2.);
|
//y.assign(2.);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
z.syncToDevice();
|
z.syncToDevice();
|
||||||
auto dimension = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
int d = 0;
|
||||||
|
auto dimension = NDArrayFactory::create<int>('c', {1}, {d});
|
||||||
auto dimensions = reinterpret_cast<int*>(dimension.buffer());
|
auto dimensions = reinterpret_cast<int*>(dimension.buffer());
|
||||||
//auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
|
//auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
|
|
Loading…
Reference in New Issue