/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018
//

#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/addBias.h>
#include <ops/declarable/helpers/im2col.h>
#include <ops/declarable/helpers/col2im.h>
#include <NDArrayFactory.h>
#include <MmulHelper.h>

namespace nd4j {
namespace ops {

#ifdef HAVE_MKLDNN
using namespace mkldnn;

void ConvolutionUtils::getMKLDNNMemoryDescPool2d(
        int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
        int bS, int iC, int iH, int iW, int oC, int oH, int oW,
        const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
        mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
        mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
        mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
    mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
    mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };

    pool_strides = { sH, sW };
    pool_kernel = { kH, kW };
    pool_padding = { pH, pW };
    pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
                       (oW - 1) * sW - iW + kW - pW };

    algorithm = poolingMode == 0 ? pooling_max
                                 : extraParam0 == 0 ? pooling_avg_exclude_padding
                                                    : pooling_avg_include_padding;
    auto type = mkldnn::memory::data_type::f32;
    auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
    auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"

    if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
        *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
        user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
        user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
        user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
        user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
        *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
        user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
        user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
        user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
        user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
        *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
        *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
        user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
        user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
        user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
        user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
    }
}
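
// Note (illustrative): pool_padding_r is the extra bottom/right padding MKL-DNN expects so that the pooling
// window of the last output position stays inside the (top/left padded) input. As a worked example with
// SAME-style geometry, iH = 5, kH = 3, sH = 2 gives oH = 3 and a top padding pH = 1, so
// (oH - 1) * sH - iH + kH - pH = 2*2 - 5 + 3 - 1 = 1, i.e. one extra padded row at the bottom. The 3d variant
// below applies the same formula per spatial dimension.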

void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
        int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
        int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
        const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
        mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
        mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
        mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
    mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
    mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };

    pool_strides = { sD, sH, sW };
    pool_kernel = { kD, kH, kW };
    pool_padding = { pD, pH, pW };
    pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
                       (oH - 1) * sH - iH + kH - pH,
                       (oW - 1) * sW - iW + kW - pW };

    algorithm = poolingMode == 0 ? pooling_max
                                 : extraParam0 == 0 ? pooling_avg_exclude_padding
                                                    : pooling_avg_include_padding;
    auto type = mkldnn::memory::data_type::f32;
    auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
    auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any"

    if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
        *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
        user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
        user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
        user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
        user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
        user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
        *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
        user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
        user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
        user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
        user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
        user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
        *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
        *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
        user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
        user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
        user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
        user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
        user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
    }
}
#endif

//////////////////////////////////////////////////////////////////////////
// [bS, iC, iD, iH, iW] is convolved to [bS, iC, kD, kH, kW, oD, oH, oW]
template <typename T>
static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {

    const int bS = volume.sizeAt(0);
    const int iC = volume.sizeAt(1);
    const int iD = volume.sizeAt(2);
    const int iH = volume.sizeAt(3);
    const int iW = volume.sizeAt(4);
    const int kD = columns.sizeAt(2);
    const int kH = columns.sizeAt(3);
    const int kW = columns.sizeAt(4);
    const int oD = columns.sizeAt(5);
    const int oH = columns.sizeAt(6);
    const int oW = columns.sizeAt(7);
    const Nd4jLong colStride0 = columns.stridesOf()[0];
    const Nd4jLong colStride1 = columns.stridesOf()[1];
    const Nd4jLong colStride2 = columns.stridesOf()[2];
    const Nd4jLong colStride3 = columns.stridesOf()[3];
    const Nd4jLong colStride4 = columns.stridesOf()[4];
    const Nd4jLong colStride5 = columns.stridesOf()[5];
    const Nd4jLong colStride6 = columns.stridesOf()[6];
    const Nd4jLong colStride7 = columns.stridesOf()[7];
    const Nd4jLong volStride0 = volume.stridesOf()[0];
    const Nd4jLong volStride1 = volume.stridesOf()[1];
    const Nd4jLong volStride2 = volume.stridesOf()[2];
    const Nd4jLong volStride3 = volume.stridesOf()[3];
    const Nd4jLong volStride4 = volume.stridesOf()[4];

    T* colBuff = columns.bufferAsT<T>();
    T* volBuff = const_cast<NDArray&>(volume).bufferAsT<T>();

    T *col, *vol;
    int volDep, volRow, volCol;

    if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo()))

        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, vol, volDep, volRow, volCol) collapse(2))
        for (int b = 0; b < bS; ++b) {
            for (int c = 0; c < iC; ++c) {
                for (int kDep = 0; kDep < kD; ++kDep) {
                    for (int kRow = 0; kRow < kH; ++kRow) {
                        for (int kCol = 0; kCol < kW; ++kCol) {
                            for (int colD = 0; colD < oD; ++colD) {
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {

                                        volDep = (-pD + kDep * dD) + colD*sD;
                                        volRow = (-pH + kRow * dH) + colH*sH;
                                        volCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7;

                                        if (static_cast<unsigned>(volDep) >= static_cast<unsigned>(iD) || static_cast<unsigned>(volRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(volCol) >= static_cast<unsigned>(iW))
                                            *col = static_cast<T>(0.);
                                        else {
                                            vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4;
                                            *col = *vol;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

    else

        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(vol, col, volDep, volRow, volCol))
        for (int b = 0; b < bS; b++) {
            for (int colD = 0; colD < oD; ++colD) {
                for (int colH = 0; colH < oH; ++colH) {
                    for (int colW = 0; colW < oW; ++colW) {
                        for (int c = 0; c < iC; ++c) {
                            for (int kDep = 0; kDep < kD; ++kDep) {
                                for (int kRow = 0; kRow < kH; ++kRow) {
                                    for (int kCol = 0; kCol < kW; ++kCol) {

                                        volDep = (-pD + kDep * dD) + colD*sD;
                                        volRow = (-pH + kRow * dH) + colH*sH;
                                        volCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7;

                                        if (static_cast<unsigned>(volDep) >= static_cast<unsigned>(iD) || static_cast<unsigned>(volRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(volCol) >= static_cast<unsigned>(iW))
                                            *col = static_cast<T>(0.);
                                        else {
                                            vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4;
                                            *col = *vol;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
}
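
// Note (illustrative): vol2col_ unrolls every kD x kH x kW window of the input volume into the 8d columns
// buffer so that 3d convolution can later be expressed as a single matrix multiply. As a small worked example,
// assume iD = iH = iW = 5, kD = kH = kW = 3, unit strides, no padding and no dilation: then
// oD = oH = oW = (5 - 3)/1 + 1 = 3, and for bS = 1, iC = 2 the columns array holds 1*2*(3*3*3)*(3*3*3) = 1458
// values, one copy of each input element for every window that covers it; positions falling outside the input
// (possible only with padding or dilation) are written as zeros above.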

//////////////////////////////////////////////////////////////////////////
// [bS, iC, kD, kH, kW, oD, oH, oW] is accumulated back to [bS, iC, iD, iH, iW]
template <typename T>
static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {

    // initial zeroing of volume content
    volume.nullify();

    const int bS = volume.sizeAt(0);
    const int iC = volume.sizeAt(1);
    const int iD = volume.sizeAt(2);
    const int iH = volume.sizeAt(3);
    const int iW = volume.sizeAt(4);
    const int kD = columns.sizeAt(2);
    const int kH = columns.sizeAt(3);
    const int kW = columns.sizeAt(4);
    const int oD = columns.sizeAt(5);
    const int oH = columns.sizeAt(6);
    const int oW = columns.sizeAt(7);
    const Nd4jLong colStride0 = columns.stridesOf()[0];
    const Nd4jLong colStride1 = columns.stridesOf()[1];
    const Nd4jLong colStride2 = columns.stridesOf()[2];
    const Nd4jLong colStride3 = columns.stridesOf()[3];
    const Nd4jLong colStride4 = columns.stridesOf()[4];
    const Nd4jLong colStride5 = columns.stridesOf()[5];
    const Nd4jLong colStride6 = columns.stridesOf()[6];
    const Nd4jLong colStride7 = columns.stridesOf()[7];
    const Nd4jLong volStride0 = volume.stridesOf()[0];
    const Nd4jLong volStride1 = volume.stridesOf()[1];
    const Nd4jLong volStride2 = volume.stridesOf()[2];
    const Nd4jLong volStride3 = volume.stridesOf()[3];
    const Nd4jLong volStride4 = volume.stridesOf()[4];

    T* volBuff = volume.bufferAsT<T>();
    T* colBuff = const_cast<NDArray&>(columns).bufferAsT<T>();

    T *col, *vol;
    int volDep, volRow, volCol;

    if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo()))

        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, vol, volDep, volRow, volCol) collapse(2))
        for (int b = 0; b < bS; b++) {
            for (int c = 0; c < iC; ++c) {
                for (int kDep = 0; kDep < kD; ++kDep) {
                    for (int kRow = 0; kRow < kH; ++kRow) {
                        for (int kCol = 0; kCol < kW; ++kCol) {
                            for (int colD = 0; colD < oD; ++colD) {
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {

                                        volDep = -pD + kDep * dD + colD * sD;
                                        volRow = -pH + kRow * dH + colH * sH;
                                        volCol = -pW + kCol * dW + colW * sW;

                                        if (static_cast<unsigned>(volDep) < static_cast<unsigned>(iD) && static_cast<unsigned>(volRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(volCol) < static_cast<unsigned>(iW)) {
                                            col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7;
                                            vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4;
                                            *vol += *col;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

    else

        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(vol, col, volDep, volRow, volCol))
        for (int b = 0; b < bS; b++) {
            for (int colD = 0; colD < oD; ++colD) {
                for (int colH = 0; colH < oH; ++colH) {
                    for (int colW = 0; colW < oW; ++colW) {
                        for (int c = 0; c < iC; ++c) {
                            for (int kDep = 0; kDep < kD; ++kDep) {
                                for (int kRow = 0; kRow < kH; ++kRow) {
                                    for (int kCol = 0; kCol < kW; ++kCol) {

                                        volDep = (-pD + kDep * dD) + colD*sD;
                                        volRow = (-pH + kRow * dH) + colH*sH;
                                        volCol = (-pW + kCol * dW) + colW*sW;

                                        if (static_cast<unsigned>(volDep) < static_cast<unsigned>(iD) && static_cast<unsigned>(volRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(volCol) < static_cast<unsigned>(iW)) {
                                            col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7;
                                            vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4;
                                            *vol += *col;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
}
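
// Note (illustrative): col2vol_ is the adjoint of vol2col_ above. Whenever the stride is smaller than the
// (dilated) kernel extent, neighbouring windows overlap and a single input voxel appears in several columns;
// that is why the loop accumulates with "*vol += *col" into the pre-zeroed volume instead of assigning, which
// is exactly the fold-back step the backward pass of a 3d convolution needs.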


#ifdef HAVE_MKLDNN
using namespace mkldnn;

void ConvolutionUtils::getMKLDNNMemoryDescConv2d(
        int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
        int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
        const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
        mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
        mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
        mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
        mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
        mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
    mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
    mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
    mkldnn::memory::dims conv_bias_tz = { oC };
    mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };

    conv_strides = { sH, sW };
    conv_padding = { pH, pW };
    conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
                       (oW - 1) * sW - iW + kW - pW };

    auto type = mkldnn::memory::data_type::f32;
    auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
    auto formatw = mkldnn::memory::format::hwio;

    if (src != nullptr && conv_src_md != nullptr) {
        *conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
        *user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
        user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
        user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
        user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
        user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (diff_src != nullptr && conv_diff_src_md != nullptr) {
        *conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
        *user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
        user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
        user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
        user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
        user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (weights != nullptr && conv_weights_md != nullptr) {
        *conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
        *user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
        user_weights_md->data.format = mkldnn_blocked; // overrides "formatw = hwio"
        user_weights_md->data.layout_desc.blocking.strides[0][0] = weights->stridesOf()[3];
        user_weights_md->data.layout_desc.blocking.strides[0][1] = weights->stridesOf()[2];
        user_weights_md->data.layout_desc.blocking.strides[0][2] = weights->stridesOf()[0];
        user_weights_md->data.layout_desc.blocking.strides[0][3] = weights->stridesOf()[1];
    }

    if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
        *conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
        *user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
        user_diff_weights_md->data.format = mkldnn_blocked; // overrides "formatw = hwio"
        user_diff_weights_md->data.layout_desc.blocking.strides[0][0] = diff_weights->stridesOf()[3];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][1] = diff_weights->stridesOf()[2];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][2] = diff_weights->stridesOf()[0];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][3] = diff_weights->stridesOf()[1];
    }

    if (bias != nullptr && conv_bias_md != nullptr) {
        *conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::any);
        *user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::x);
    }

    if (dst != nullptr && conv_dst_md != nullptr) {
        *conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format::any);
        *user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
        user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
        user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
        user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
        user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
    }
}
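
// Note (illustrative): nd4j stores conv2d weights as [kH, kW, iC, oC], while the descriptor above declares the
// logical dimensions in MKL-DNN's (oC, iC, kH, kW) order. Setting the blocking strides to
// { stridesOf()[3], stridesOf()[2], stridesOf()[0], stridesOf()[1] } tells MKL-DNN where oC, iC, kH and kW live
// in the original nd4j buffer, so the weights never have to be physically transposed before a possible reorder
// into the library's preferred blocked layout.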

void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
        int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
        int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
        const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
        mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
        mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
        mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
        mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
        mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
    mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
    mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
    mkldnn::memory::dims conv_bias_tz = { oC };
    mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };

    conv_strides = { sD, sH, sW };
    conv_padding = { pD, pH, pW };
    conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
                       (oH - 1) * sH - iH + kH - pH,
                       (oW - 1) * sW - iW + kW - pW };

    auto type = mkldnn::memory::data_type::f32;
    auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
    auto formatw = mkldnn::memory::format::dhwio;

    if (src != nullptr && conv_src_md != nullptr) {
        *conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
        *user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
        user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
        user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
        user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
        user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
        user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (diff_src != nullptr && conv_diff_src_md != nullptr) {
        *conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
        *user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
        user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
        user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
        user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
        user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
        user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (weights != nullptr && conv_weights_md != nullptr) {
        *conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
        *user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
        user_weights_md->data.format = mkldnn_blocked; // overrides "formatw = dhwio"
        user_weights_md->data.layout_desc.blocking.strides[0][0] = weights->stridesOf()[4];
        user_weights_md->data.layout_desc.blocking.strides[0][1] = weights->stridesOf()[3];
        user_weights_md->data.layout_desc.blocking.strides[0][2] = weights->stridesOf()[0];
        user_weights_md->data.layout_desc.blocking.strides[0][3] = weights->stridesOf()[1];
        user_weights_md->data.layout_desc.blocking.strides[0][4] = weights->stridesOf()[2];
    }

    if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
        *conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
        *user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
        user_diff_weights_md->data.format = mkldnn_blocked; // overrides "formatw = dhwio"
        user_diff_weights_md->data.layout_desc.blocking.strides[0][0] = diff_weights->stridesOf()[4];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][1] = diff_weights->stridesOf()[3];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][2] = diff_weights->stridesOf()[0];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][3] = diff_weights->stridesOf()[1];
        user_diff_weights_md->data.layout_desc.blocking.strides[0][4] = diff_weights->stridesOf()[2];
    }

    if (bias != nullptr && conv_bias_md != nullptr) {
        *conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::any);
        *user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::x);
    }

    if (dst != nullptr && conv_dst_md != nullptr) {
        *conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format::any);
        *user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
        user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
        user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
        user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
        user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
        user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
    }
}
#endif

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {

    // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    // weights [kH, kW, iC, oC] always
    // bias    [oC]
    // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

    // kH         filter (kernel) height
    // kW         filter (kernel) width
    // sH         strides height
    // sW         strides width
    // pH         paddings height
    // pW         paddings width
    // dH         dilations height
    // dW         dilations width
    // isSameMode 0-VALID, 1-SAME
    // isNCHW     1-NCHW, 0-NHWC
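
    // Note (illustrative): with these parameters the spatial output size in VALID mode is
    // oH = (iH + 2*pH - ((kH - 1)*dH + 1)) / sH + 1 (and likewise for width); e.g. iH = 28, kH = 3, sH = 1,
    // pH = 0, dH = 1 gives oH = (28 - 3)/1 + 1 = 26. In SAME mode calcPadding2D below adjusts pH/pW so that
    // oH = ceil(iH / sH) instead.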

    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

#ifdef HAVE_MKLDNN
    if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<X, Y>()) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
        if (streams.empty()) {
            streams.push_back(MKLDNNStream("conv2d"));
        }

        if (streams[0].checkAndReset({input, weights, bias}, {output}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW})) {
            mkldnn_memory_desc_t empty;
            mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
            mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
            mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;

            ConvolutionUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
                    bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output,
                    &conv_src_md, nullptr, &conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md,
                    &user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md,
                    conv_strides, conv_padding, conv_padding_r);

            auto conv_desc = bias != nullptr
                    ? convolution_forward::desc(prop_kind::forward,
                            convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
                            conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
                    : convolution_forward::desc(prop_kind::forward,
                            convolution_direct, conv_src_md, conv_weights_md,
                            conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);

            auto engine = streams[0].getEngine();
            auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
            auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
            auto user_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
            auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());

            auto conv_src_memory = user_src_memory;
            streams[0].addMemory(user_src_memory);
            if (mkldnn::memory::primitive_desc(conv_prim_desc.src_primitive_desc())
                    != user_src_memory.get_primitive_desc()) {
                conv_src_memory = mkldnn::memory(conv_prim_desc.src_primitive_desc());
                streams[0].addMemory(conv_src_memory);
                streams[0].addOperation(reorder(user_src_memory, conv_src_memory));
            }

            auto conv_weights_memory = user_weights_memory;
            streams[0].addMemory(user_weights_memory);
            if (mkldnn::memory::primitive_desc(conv_prim_desc.weights_primitive_desc())
                    != user_weights_memory.get_primitive_desc()) {
                conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_primitive_desc());
                streams[0].addMemory(conv_weights_memory);
                streams[0].addOperation(reorder(user_weights_memory, conv_weights_memory));
            }

            auto conv_dst_memory = user_dst_memory;
            streams[0].addMemory(user_dst_memory);
            if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
                    != user_dst_memory.get_primitive_desc()) {
                conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_primitive_desc());
                streams[0].addMemory(conv_dst_memory);
            }

            if (bias != nullptr) {
                auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_primitive_desc(), const_cast<NDArray*>(bias)->buffer());
                streams[0].addMemory(conv_bias_memory);
                streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_bias_memory, conv_dst_memory));
            } else {
                streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_dst_memory));
            }

            if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
                    != user_dst_memory.get_primitive_desc()) {
                streams[0].addOperation(reorder(conv_dst_memory, user_dst_memory));
            }
        }

        streams[0].submitAndWait();
        return;
    }
#endif
    nd4j_debug("MKL-DNN is not used for conv2d!\n", 0);

    std::vector<int> permutForOutput;

    if(isNCHW)
        permutForOutput = {0, 3, 1, 2};                                             // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
    else
        input = new NDArray(input->permute({0, 3, 1, 2}));                          // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC

    NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext());
    NDArray colP = col.permute({0, 5, 3, 4, 1, 2});                                 // {bS, iC, kH, kW, oH, oW}
    NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext());

    //----- calculation of output -----//
    auto ctx = block.launchContext();
    helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convolved to [bS, iC, kH, kW, oH, oW]
    MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {});         // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]

    //----- assign mmulResult to output -----//
    if(isNCHW) {
        mmulResult.reshapei({bS, oH, oW, oC});
        mmulResult.permutei(permutForOutput);
    }
    output->assign(mmulResult);

    //----- add biases if required -----//
    if(bias)
        // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
        helpers::addBias(*output, *bias, isNCHW);

    if(!isNCHW)
        delete input;

}
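
// Note (illustrative): the CPU fallback above implements conv2d as im2col followed by one GEMM. Reading the
// shapes through a concrete case with bS = 1, iC = 3, iH = iW = 32, kH = kW = 3, oC = 16, unit stride, VALID:
// im2col fills col with shape [1, 30, 30, 3, 3, 3]; contracting its axes {3,4,5} with axes {0,1,2} of the
// [3, 3, 3, 16] weights is a [1*30*30, 27] x [27, 16] matrix multiply yielding [1, 30, 30, 16], which is then
// permuted to NCHW when requested, and the oC-length bias is broadcast-added per output channel.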

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {

    // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    // weights [kH, kW, iC, oC] always
    // bias    [oC]
    // gradO   [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    // gradI   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
    // gradW   [kH, kW, iC, oC] always
    // gradB   [oC]

    // kH         filter (kernel) height
    // kW         filter (kernel) width
    // sH         strides height
    // sW         strides width
    // pH         paddings height
    // pW         paddings width
    // dH         dilations height
    // dW         dilations width
    // isSameMode 0-VALID, 1-SAME
    // isNCHW     0-NHWC, 1-NCHW
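
    // Note (illustrative): conv2dBP_ produces up to three gradients from gradO. gradB is, in effect, gradO
    // reduced over all axes except the channel axis; gradW is computed by the convolution_backward_weights
    // primitive below (conceptually, the im2col'ed input contracted with gradO); and gradI comes from
    // convolution_backward_data (conceptually, gradO convolved with the flipped weights and folded back onto
    // the input, the 2d analogue of col2vol_ above).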

    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

#ifdef HAVE_MKLDNN
    if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<X, Y>()) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
        if (streams.empty()) {
            streams.push_back(MKLDNNStream("conv2d_bp_weights"));
            streams.push_back(MKLDNNStream("conv2d_bp_data"));
        }

        bool resetW = streams[0].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW});
        bool resetI = streams[1].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW});
        if (resetW || resetI) {
            mkldnn_memory_desc_t empty;
            mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
                                 conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
            mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
                                 user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
            mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;

            ConvolutionUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
                    bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO,
                    &conv_src_md, &conv_diff_src_md, &conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
                    &user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md,
                    conv_strides, conv_padding, conv_padding_r);

            auto conv_desc = gradB != nullptr
                    ? convolution_forward::desc(prop_kind::forward,
                            convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
                            conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
                    : convolution_forward::desc(prop_kind::forward,
                            convolution_direct, conv_src_md, conv_weights_md,
                            conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);

            auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, streams[0].getEngine());

            if (gradW != nullptr) {
                auto convW_desc = gradB != nullptr
                        ? convolution_backward_weights::desc(
                                convolution_direct, conv_src_md, conv_diff_weights_md, conv_bias_md,
                                conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
                        : convolution_backward_weights::desc(
                                convolution_direct, conv_src_md, conv_diff_weights_md,
                                conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);

                auto engine = streams[0].getEngine();
                auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
                auto userW_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
                auto userW_weights_memory = mkldnn::memory({user_diff_weights_md, engine}, gradW->buffer());
                auto userW_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());

                auto convW_src_memory = userW_src_memory;
                streams[0].addMemory(userW_src_memory);
                if (mkldnn::memory::primitive_desc(convW_prim_desc.src_primitive_desc())
                        != userW_src_memory.get_primitive_desc()) {
                    convW_src_memory = mkldnn::memory(convW_prim_desc.src_primitive_desc());
                    streams[0].addMemory(convW_src_memory);
                    streams[0].addOperation(reorder(userW_src_memory, convW_src_memory));
                }

                auto convW_weights_memory = userW_weights_memory;
                streams[0].addMemory(userW_weights_memory);
                if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
                        != userW_weights_memory.get_primitive_desc()) {
                    convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_primitive_desc());
                    streams[0].addMemory(convW_weights_memory);
                }

                auto convW_dst_memory = userW_dst_memory;
                streams[0].addMemory(userW_dst_memory);
                if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_dst_primitive_desc())
                        != userW_dst_memory.get_primitive_desc()) {
                    convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_primitive_desc());
                    streams[0].addMemory(convW_dst_memory);
                    streams[0].addOperation(reorder(userW_dst_memory, convW_dst_memory));
                }

                if (gradB != nullptr) {
                    auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_primitive_desc(), gradB->buffer());
                    streams[0].addMemory(convW_bias_memory);
                    streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory, convW_bias_memory));
                } else {
                    streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory));
                }

                if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
                        != userW_weights_memory.get_primitive_desc()) {
                    streams[0].addOperation(reorder(convW_weights_memory, userW_weights_memory));
                }
            }

            if (gradI != nullptr) {
                auto convI_desc =
                        convolution_backward_data::desc(
                                convolution_direct, conv_diff_src_md, conv_weights_md,
                                conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);

                auto engine = streams[1].getEngine();
                auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
                auto userI_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI->buffer());
                auto userI_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
                auto userI_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());

                auto convI_src_memory = userI_src_memory;
                streams[1].addMemory(userI_src_memory);
                if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
                        != userI_src_memory.get_primitive_desc()) {
                    convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_primitive_desc());
                    streams[1].addMemory(convI_src_memory);
                }

                auto convI_weights_memory = userI_weights_memory;
                streams[1].addMemory(userI_weights_memory);
                if (mkldnn::memory::primitive_desc(convI_prim_desc.weights_primitive_desc())
|
|
!= userI_weights_memory.get_primitive_desc()) {
|
|
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_primitive_desc());
|
|
streams[1].addMemory(convI_weights_memory);
|
|
streams[1].addOperation(reorder(userI_weights_memory, convI_weights_memory));
|
|
}
|
|
|
|
auto convI_dst_memory = userI_dst_memory;
|
|
streams[1].addMemory(userI_dst_memory);
|
|
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_dst_primitive_desc())
|
|
!= userI_dst_memory.get_primitive_desc()) {
|
|
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_primitive_desc());
|
|
streams[1].addMemory(convI_dst_memory);
|
|
streams[1].addOperation(reorder(userI_dst_memory, convI_dst_memory));
|
|
}
|
|
|
|
streams[1].addOperation(convolution_backward_data(convI_prim_desc, convI_dst_memory, convI_weights_memory, convI_src_memory));
|
|
|
|
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
|
|
!= userI_src_memory.get_primitive_desc()) {
|
|
streams[1].addOperation(reorder(convI_src_memory, userI_src_memory));
|
|
}
|
|
}
|
|
}
|
|
|
|
if (gradW != nullptr) {
|
|
streams[0].submitAndWait();
|
|
}
|
|
if (gradI != nullptr) {
|
|
streams[1].submitAndWait();
|
|
}
|
|
return;
|
|
}
|
|
#endif
|
|
nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0);
|
|
|
|
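    // The CPU fallback below realizes conv2d backprop through im2col/col2im:
    //   1) im2col unfolds the input into "columns" of shape [bS, iC, kH, kW, oH, oW];
    //   2) gradW = columns contracted with gradO over the {bS, oH, oW} axes -> [iC, kH, kW, oC], permuted to [kH, kW, iC, oC];
    //   3) gradB = gradO summed over bS, oH, oW -> [oC];
    //   4) gradI = col2im( weights contracted with gradO ) -> [bS, iC, iH, iW].
    // Illustrative shape walk-through (hypothetical numbers, not taken from any test): with bS=2, iC=3, oC=8, kH=kW=3,
    // oH=oW=5, step 2 contracts a [2,3,3,3,5,5] columns tensor with a [2,5,5,8] (NHWC) gradO over {bS,oH,oW}, leaving [3,3,3,8].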
    std::vector<int> gradOaxesForDot;

    if(!isNCHW) {
        gradOaxesForDot = {0, 1, 2};                                            // bS, oH, oW
        input = new NDArray(input->permute({0, 3, 1, 2}));                      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradI = new NDArray(gradI->permute({0, 3, 1, 2}));                      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
    } else {
        gradOaxesForDot = {0, 2, 3};                                            // bS, oH, oW
    }

    NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());

    // ----- calculation of gradW ----- //
    if(gradW) {
        auto ctx = block.launchContext();
        helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));   // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
        nd4j::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3});    // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
    }

    // ----- calculation of gradB ----- //
    if(gradB) {
        NDArray* gradBR = gradB;
        if(gradB->rankOf() == 2)
            gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
        gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot);                              // sum over bS, oH, oW
        if(gradBR != gradB)
            delete gradBR;
    }
    //----- calculation of gradI -----//
    nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5});    // [kH, kW, iC, oC]/[oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]

    helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);           // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]

    if(!isNCHW) {
        delete input;
        delete gradI;
    }
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void depthwiseConv2d_(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {

    // input    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    // weights  [kH, kW, iC, mC] always
    // bias     [oC] = iC*mC
    // output   [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)

    // kH         filter(kernel) height
    // kW         filter(kernel) width
    // sH         strides height
    // sW         strides width
    // pH         paddings height
    // pW         paddings width
    // dH         dilations height
    // dW         dilations width
    // isSameMode 0-VALID, 1-SAME
    // isNCHW     0-NHWC, 1-NCHW

    int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC);                           // channels multiplier

    std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}};  // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
    std::vector<std::vector<Nd4jLong>> modifOutput;
    std::vector<Nd4jLong> outReShape;

    if(!isNCHW) {
        outReShape = {bS, oH, oW, iC, mC};                  // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
        modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};     // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
        input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
    }
    else {
        outReShape = {bS, iC, mC, oH, oW};                  // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
        modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}};     // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
    }

    if(isSameMode)                                          // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
    NDArray outputReshaped = output->reshape(output->ordering(), outReShape);

    helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
    MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);                            // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]

    if(bias)
        output->applyBroadcast(broadcast::Add, {indIOioC}, bias);

    if(!isNCHW)
        delete input;
}
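// Note on the depthwise kernel above: the contraction treats iC as a "batch" axis of the matrix product, so each
// input channel c is convolved only with its own mC filter slices, i.e. an [bS*oH*oW, kH*kW] patch matrix times a
// [kH*kW, mC] weight slice per channel. As an illustrative sketch (hypothetical numbers, not taken from the code):
// with iC=3, mC=2, kH=kW=3, bS=1, oH=oW=4 this is 3 independent [16,9] x [9,2] = [16,2] products, scattered back
// into the [bS,oH,oW,iC*mC] output.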
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {

    // input    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    // weights  [kH, kW, iC, mC] always
    // bias     [oC] = [iC*mC]
    // gradO    [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
    // gradW    [kH, kW, iC, mC] always
    // gradB    [oC]

    // kH         filter(kernel) height
    // kW         filter(kernel) width
    // sH         strides height
    // sW         strides width
    // pH         paddings height
    // pW         paddings width
    // dH         dilations height
    // dW         dilations width
    // isSameMode 0-VALID, 1-SAME
    // isNCHW     0-NHWC, 1-NCHW

    int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC);                           // channels multiplier

    std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}};   // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
    std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2;
    std::vector<Nd4jLong> gradOreShape;

    if(!isNCHW) {
        gradOreShape = {bS, oH, oW, iC, mC};                // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
        modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};     // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
        modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}};       // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
        input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
        gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
    }
    else {
        gradOreShape = {bS, iC, mC, oH, oW};                // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
        modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}};     // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
        modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}};       // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
    }

    if(isSameMode)                                          // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
    NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape);

    // ----- calculation of gradW and gradB ----- //

    helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
    nd4j::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}});                        // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]

    // ----- calculation of gradB ----- //
    if(gradB) {
        NDArray* gradBR = gradB;
        if(gradB->rankOf() == 2)
            gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
        gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1});                          // sum over bS, oH, oW

        if(gradBR != gradB)
            delete gradBR;
    }

    //----- calculation of gradI -----//
    nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns);  // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
    helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                       // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]

    if(!isNCHW) {
        delete input;
        delete gradI;
    }
}
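// Backprop note for the depthwise kernel above: gradW reuses the same per-channel batched product as the forward
// pass with the roles of weights and output swapped (columns transposed times gradO per channel), while gradI
// multiplies the weights with gradO per channel to rebuild the [iC, kH*kW, bS*oH*oW] columns and folds them back
// into the input layout via col2im. This mirrors the dense conv2d_bp path, just without mixing channels.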
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {

    // input        [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    // weightsDepth [kH, kW, iC, mC] always
    // weightsPoint [1, 1, iC*mC, oC] always
    // bias         [oC], oC = iC*mC if weightsPoint=nullptr
    // output is    [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

    // kH         filter(kernel) height
    // kW         filter(kernel) width
    // sH         strides height
    // sW         strides width
    // pH         paddings height
    // pW         paddings width
    // dH         dilations height
    // dW         dilations width
    // isSameMode 0-VALID, 1-SAME
    // isNCHW     1-NCHW, 0-NHWC

    int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier

    NDArray* outputDepth = output;
    if(weightsPoint)                                        // if pointwise convolution is expected
        outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());

    // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
    ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);

    // ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
    if (weightsPoint) {
        ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW);   // in this case oH=iH, oW=iW
        delete outputDepth;
    }
}
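// Separable convolution above = depthwise stage followed by an optional 1x1 pointwise stage. A quick illustrative
// parameter count (hypothetical numbers, not from the code): with iC=32, oC=64, kH=kW=3, mC=1 the two stages hold
// 3*3*32*1 + 1*1*32*64 = 2336 weights, versus 3*3*32*64 = 18432 for a dense conv2d with the same receptive field.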
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) {
    // input  has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
    // output has shape [bS, iC, factorH*iH, factorW*iW] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)

    const T* x = input.bufferAsT<T>();
    T* z = output.bufferAsT<T>();

    const uint dimIH = isNCHW ? 2 : 1;
    const uint dimIC = isNCHW ? 1 : 3;

    const uint bS = input.sizeAt(0);
    const uint iC = input.sizeAt(dimIC);
    const uint oH = output.sizeAt(dimIH);
    const uint oW = output.sizeAt(dimIH + 1);

    const Nd4jLong xStride0 = input.stridesOf()[0];
    const Nd4jLong xStride1 = input.stridesOf()[dimIC];
    const Nd4jLong xStride2 = input.stridesOf()[dimIH];
    const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1];

    const Nd4jLong zStride0 = output.stridesOf()[0];
    const Nd4jLong zStride1 = output.stridesOf()[dimIC];
    const Nd4jLong zStride2 = output.stridesOf()[dimIH];
    const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1];

    uint xCoord2, xCoord3;

    // loop through output array
    PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(4) private(xCoord2, xCoord3))
    for(uint b = 0; b < bS; ++b) {
        for(uint c = 0; c < iC; ++c) {
            for(uint h = 0; h < oH; ++h) {
                for(uint w = 0; w < oW; ++w) {

                    xCoord2 = h / factorH;
                    xCoord3 = w / factorW;

                    z[b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3] = x[b*xStride0 + c*xStride1 + xCoord2*xStride2 + xCoord3*xStride3];
                }
            }
        }
    }
}
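// Nearest-neighbour semantics of the kernel above: every output cell (h, w) simply copies input cell
// (h / factorH, w / factorW). For example, with factorH = factorW = 2 output rows 0-1 read input row 0,
// rows 2-3 read input row 1, and so on; no interpolation is performed.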
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) {
    // input  has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
    // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)

    const T* x = input.bufferAsT<T>();
    T* z = output.bufferAsT<T>();

    const uint dimID = isNCDHW ? 2 : 1;
    const uint dimIC = isNCDHW ? 1 : 4;

    const uint bS = input.sizeAt(0);
    const uint iC = input.sizeAt(dimIC);
    const uint oD = output.sizeAt(dimID);
    const uint oH = output.sizeAt(dimID + 1);
    const uint oW = output.sizeAt(dimID + 2);

    const Nd4jLong xStride0 = input.stridesOf()[0];
    const Nd4jLong xStride1 = input.stridesOf()[dimIC];
    const Nd4jLong xStride2 = input.stridesOf()[dimID];
    const Nd4jLong xStride3 = input.stridesOf()[dimID + 1];
    const Nd4jLong xStride4 = input.stridesOf()[dimID + 2];

    const Nd4jLong zStride0 = output.stridesOf()[0];
    const Nd4jLong zStride1 = output.stridesOf()[dimIC];
    const Nd4jLong zStride2 = output.stridesOf()[dimID];
    const Nd4jLong zStride3 = output.stridesOf()[dimID + 1];
    const Nd4jLong zStride4 = output.stridesOf()[dimID + 2];

    uint xCoord2, xCoord3, xCoord4;

    // loop through output array
    PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(5) private(xCoord2, xCoord3, xCoord4))
    for(uint b = 0; b < bS; ++b) {
        for(uint c = 0; c < iC; ++c) {
            for(uint d = 0; d < oD; ++d) {
                for(uint h = 0; h < oH; ++h) {
                    for(uint w = 0; w < oW; ++w) {

                        xCoord2 = d / factorD;
                        xCoord3 = h / factorH;
                        xCoord4 = w / factorW;

                        z[b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4] = x[b*xStride0 + c*xStride1 + xCoord2*xStride2 + xCoord3*xStride3 + xCoord4*xStride4];
                    }
                }
            }
        }
    }
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
    // gradO has shape [bS, iC, factorH*iH, factorW*iW] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
    // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)

    gradI.nullify();

    const T* x = gradO.bufferAsT<T>();
    T* z = gradI.bufferAsT<T>();

    const uint dimIH = isNCHW ? 2 : 1;
    const uint dimIC = isNCHW ? 1 : 3;

    const uint bS = gradI.sizeAt(0);
    const uint iC = gradI.sizeAt(dimIC);
    const uint iH = gradI.sizeAt(dimIH);
    const uint iW = gradI.sizeAt(dimIH + 1);

    const uint factorH = gradO.sizeAt(dimIH) / iH;
    const uint factorW = gradO.sizeAt(dimIH + 1) / iW;

    const Nd4jLong xStride0 = gradO.stridesOf()[0];
    const Nd4jLong xStride1 = gradO.stridesOf()[dimIC];
    const Nd4jLong xStride2 = gradO.stridesOf()[dimIH];
    const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1];

    const Nd4jLong zStride0 = gradI.stridesOf()[0];
    const Nd4jLong zStride1 = gradI.stridesOf()[dimIC];
    const Nd4jLong zStride2 = gradI.stridesOf()[dimIH];
    const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1];

    // loop through gradI array; each gradI cell (h, w) accumulates the factorH x factorW block of gradO that the
    // forward pass filled from it, i.e. gradO rows h*factorH .. (h+1)*factorH-1 and columns w*factorW .. (w+1)*factorW-1
    PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(4))
    for(uint b = 0; b < bS; ++b) {
        for(uint c = 0; c < iC; ++c) {
            for(uint h = 0; h < iH; ++h) {
                for(uint w = 0; w < iW; ++w) {

                    const auto zOffset = b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3;

                    for(uint xh = h * factorH; xh < (h + 1) * factorH; ++xh)
                        for(uint xw = w * factorW; xw < (w + 1) * factorW; ++xw)
                            z[zOffset] += x[b*xStride0 + c*xStride1 + xh*xStride2 + xw*xStride3];
                }
            }
        }
    }
}
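// Quick sanity example for the backward kernel above (illustrative only): with factorH = factorW = 2 and a gradO
// filled with ones, every gradI cell receives 2*2 = 4, matching the fact that each input pixel was copied to four
// output pixels in the forward pass.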
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) {

    // gradO has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
    // gradI has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)

    gradI.nullify();

    const T* x = gradO.bufferAsT<T>();
    T* z = gradI.bufferAsT<T>();

    const uint dimID = isNCDHW ? 2 : 1;
    const uint dimIC = isNCDHW ? 1 : 4;

    const uint bS = gradI.sizeAt(0);
    const uint iC = gradI.sizeAt(dimIC);
    const uint iD = gradI.sizeAt(dimID);
    const uint iH = gradI.sizeAt(dimID + 1);
    const uint iW = gradI.sizeAt(dimID + 2);

    const uint factorD = gradO.sizeAt(dimID) / iD;
    const uint factorH = gradO.sizeAt(dimID + 1) / iH;
    const uint factorW = gradO.sizeAt(dimID + 2) / iW;

    const Nd4jLong xStride0 = gradO.stridesOf()[0];
    const Nd4jLong xStride1 = gradO.stridesOf()[dimIC];
    const Nd4jLong xStride2 = gradO.stridesOf()[dimID];
    const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1];
    const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2];

    const Nd4jLong zStride0 = gradI.stridesOf()[0];
    const Nd4jLong zStride1 = gradI.stridesOf()[dimIC];
    const Nd4jLong zStride2 = gradI.stridesOf()[dimID];
    const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1];
    const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2];

    // loop through gradI array; each gradI cell (d, h, w) accumulates the factorD x factorH x factorW block of gradO
    // that the forward pass filled from it
    PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(5))
    for(uint b = 0; b < bS; ++b) {
        for(uint c = 0; c < iC; ++c) {
            for(uint d = 0; d < iD; ++d) {
                for(uint h = 0; h < iH; ++h) {
                    for(uint w = 0; w < iW; ++w) {

                        const auto zOffset = b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4;

                        for(uint xd = d * factorD; xd < (d + 1) * factorD; ++xd)
                            for(uint xh = h * factorH; xh < (h + 1) * factorH; ++xh)
                                for(uint xw = w * factorW; xw < (w + 1) * factorW; ++xw)
                                    z[zOffset] += x[b*xStride0 + c*xStride1 + xd*xStride2 + xh*xStride3 + xw*xStride4];
                    }
                }
            }
        }
    }
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void pooling2d_(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
    // input is  [bS, iC, iH, iW]
    // output is [bS, iC, oH, oW]
    T* out = output.bufferAsT<T>();
    T* in  = const_cast<NDArray&>(input).bufferAsT<T>();

    const int kHEff = kH + (kH-1)*(dH-1);
    const int kWEff = kW + (kW-1)*(dW-1);

    const int bS = input.sizeAt(0);
    const int iC = input.sizeAt(1);
    const int iH = input.sizeAt(2);
    const int iW = input.sizeAt(3);
    const int oC = output.sizeAt(1);
    const int oH = output.sizeAt(2);
    const int oW = output.sizeAt(3);

#ifdef HAVE_MKLDNN
    if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
        if (streams.empty()) {
            streams.push_back(MKLDNNStream("pooling2d"));
        }

        if (streams[0].checkAndReset({&input}, {&output}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0})) {
            mkldnn_memory_desc_t empty;
            mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
            mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
            mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
            mkldnn::algorithm algorithm;

            ConvolutionUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                    bS, iC, iH, iW, oC, oH, oW, &input, nullptr, &output, algorithm,
                    &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r);

            auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);

            auto engine = streams[0].getEngine();
            auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
            auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
            auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output.buffer());

            auto pool_src_memory = user_src_memory;
            streams[0].addMemory(user_src_memory);
            if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
                    != user_src_memory.get_primitive_desc()) {
                pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
                streams[0].addMemory(pool_src_memory);
                streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
            }

            auto pool_dst_memory = user_dst_memory;
            streams[0].addMemory(user_dst_memory);
            if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
                    != user_dst_memory.get_primitive_desc()) {
                pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
                streams[0].addMemory(pool_dst_memory);
            }

            streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory));

            if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
                    != user_dst_memory.get_primitive_desc()) {
                streams[0].addOperation(reorder(pool_dst_memory, user_dst_memory));
            }
        }

        streams[0].submitAndWait();
        return;
    }
#endif
    nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0);

    const Nd4jLong iStride0 = input.stridesOf()[0];
    const Nd4jLong iStride1 = input.stridesOf()[1];
    const Nd4jLong iStride2 = input.stridesOf()[2];
    const Nd4jLong iStride3 = input.stridesOf()[3];
    const Nd4jLong oStride0 = output.stridesOf()[0];
    const Nd4jLong oStride1 = output.stridesOf()[1];
    const Nd4jLong oStride2 = output.stridesOf()[2];
    const Nd4jLong oStride3 = output.stridesOf()[3];

    const Nd4jLong iStep2 = dH*iStride2;
    const Nd4jLong iStep3 = dW*iStride3;
    const int kProd = kH*kW;

    Nd4jLong hstart, wstart, hend, wend;
    T *pIn;

    if(poolingMode == 0) {        // max
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {

                        pIn = in + b * iStride0 + c * iStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * ((-hstart + dH - 1) / dH);   // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(-hstart) / static_cast<T>(dH));
                        if(wstart < 0)
                            wstart += dW * ((-wstart + dW - 1) / dW);   // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(-wstart) / static_cast<T>(dW));
                        if(hend > iH)
                            hend -= dH * ((hend-iH + dH - 1) / dH);     // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(hend-iH) / static_cast<T>(dH));
                        if(wend > iW)
                            wend -= dW * ((wend-iW + dW - 1) / dW);     // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(wend-iW) / static_cast<T>(dW));

                        hstart *= iStride2;
                        hend   *= iStride2;
                        wstart *= iStride3;
                        wend   *= iStride3;

                        T max = -DataTypeUtils::max<T>();

                        for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                            for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) {
                                T val = pIn[kh + kw];
                                if (val > max)
                                    max = val;
                            }
                        out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max;
                    }
                }
            }
        }
    }
    /*************************************************************************/
    else if(poolingMode == 1) {   // avg
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {

                        pIn = in + b * iStride0 + c * iStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * ((-hstart + dH - 1) / dH);
                        if(wstart < 0)
                            wstart += dW * ((-wstart + dW - 1) / dW);
                        if(hend > iH)
                            hend -= dH * ((hend-iH + dH - 1) / dH);
                        if(wend > iW)
                            wend -= dW * ((wend-iW + dW - 1) / dW);

                        hstart *= iStride2;
                        hend   *= iStride2;
                        wstart *= iStride3;
                        wend   *= iStride3;

                        T sum = static_cast<T>(0.f);

                        for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                            for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                sum += pIn[kh + kw];

                        if (extraParam0 == 0) {           // Exclude padding
                            int hCount = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1);
                            int wCount = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1);
                            sum /= static_cast<T>(hCount * wCount);   // Accounts for dilation
                        }
                        else if (extraParam0 == 1)        // Include padding
                            sum /= kProd;

                        out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                    }
                }
            }
        }
    }
    /*************************************************************************/
    else if(poolingMode == 2) {   // pnorm
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {

                        pIn = in + b * iStride0 + c * iStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * ((-hstart + dH - 1) / dH);
                        if(wstart < 0)
                            wstart += dW * ((-wstart + dW - 1) / dW);
                        if(hend > iH)
                            hend -= dH * ((hend-iH + dH - 1) / dH);
                        if(wend > iW)
                            wend -= dW * ((wend-iW + dW - 1) / dW);

                        hstart *= iStride2;
                        hend   *= iStride2;
                        wstart *= iStride3;
                        wend   *= iStride3;

                        T sum = static_cast<T>(0.f);

                        for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                            for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                sum += nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0);

                        sum = nd4j::math::nd4j_pow<T,T,T>(sum, static_cast<T>(1.f) / extraParam0);

                        out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                    }
                }
            }
        }
    }
    else {
        nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
        throw "";
    }
}
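// Window-clipping arithmetic used by the pooling kernels above: hstart/hend first describe the (possibly
// out-of-image) dilated window, then get snapped to the first/last tap that falls inside the image, in multiples of
// the dilation dH. Illustrative example (hypothetical numbers): kH=3, dH=2 (kHEff=5), sH=1, pH=2, oh=0 gives
// hstart=-2, hend=3; the correction adds dH*ceil(2/2)=2, so the window actually reads rows 0 and 2. After clipping,
// the bounds are converted to buffer offsets by multiplying with the corresponding strides, which is why the inner
// loops can simply step by iStep2/iStep3.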
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void pooling3d_(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
    // input is  [bS, iC, iD, iH, iW]
    // output is [bS, iC, oD, oH, oW]
    T* out = output.bufferAsT<T>();
    T* in  = const_cast<NDArray&>(input).bufferAsT<T>();

    const int kDEff = kD + (kD-1)*(dD-1);
    const int kHEff = kH + (kH-1)*(dH-1);
    const int kWEff = kW + (kW-1)*(dW-1);

    const int bS = input.sizeAt(0);
    const int iC = input.sizeAt(1);
    const int iD = input.sizeAt(2);
    const int iH = input.sizeAt(3);
    const int iW = input.sizeAt(4);
    const int oC = output.sizeAt(1);
    const int oD = output.sizeAt(2);
    const int oH = output.sizeAt(3);
    const int oW = output.sizeAt(4);

#ifdef HAVE_MKLDNN
    if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
        if (streams.empty()) {
            streams.push_back(MKLDNNStream("pooling3d"));
        }

        if (streams[0].checkAndReset({&input}, {&output}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0})) {
            mkldnn_memory_desc_t empty;
            mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
            mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
            mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
            mkldnn::algorithm algorithm;

            ConvolutionUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
                    bS, iC, iD, iH, iW, oC, oD, oH, oW, &input, nullptr, &output, algorithm,
                    &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r);

            auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);

            auto engine = streams[0].getEngine();
            auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
            auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
            auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output.buffer());

            auto pool_src_memory = user_src_memory;
            streams[0].addMemory(user_src_memory);
            if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
                    != user_src_memory.get_primitive_desc()) {
                pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
                streams[0].addMemory(pool_src_memory);
                streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
            }

            auto pool_dst_memory = user_dst_memory;
            streams[0].addMemory(user_dst_memory);
            if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
                    != user_dst_memory.get_primitive_desc()) {
                pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
                streams[0].addMemory(pool_dst_memory);
            }

            streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory));

            if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
                    != user_dst_memory.get_primitive_desc()) {
                streams[0].addOperation(reorder(pool_dst_memory, user_dst_memory));
            }
        }

        streams[0].submitAndWait();
        return;
    }
#endif
    nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0);

    const Nd4jLong iStride0 = input.stridesOf()[0];
    const Nd4jLong iStride1 = input.stridesOf()[1];
    const Nd4jLong iStride2 = input.stridesOf()[2];
    const Nd4jLong iStride3 = input.stridesOf()[3];
    const Nd4jLong iStride4 = input.stridesOf()[4];
    const Nd4jLong oStride0 = output.stridesOf()[0];
    const Nd4jLong oStride1 = output.stridesOf()[1];
    const Nd4jLong oStride2 = output.stridesOf()[2];
    const Nd4jLong oStride3 = output.stridesOf()[3];
    const Nd4jLong oStride4 = output.stridesOf()[4];
    const Nd4jLong iStep2 = dD*iStride2;
    const Nd4jLong iStep3 = dH*iStride3;
    const Nd4jLong iStep4 = dW*iStride4;
    const int kProd = kD*kH*kW;

    Nd4jLong dstart, hstart, wstart, dend, hend, wend;
    T sum, *pIn;

    if(poolingMode == 0) {        // max
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int od = 0; od < oD; ++od) {
                    for(int oh = 0; oh < oH; ++oh) {
                        for(int ow = 0; ow < oW; ++ow) {

                            pIn = in + b * iStride0 + c * iStride1;

                            dstart = od * sD - pD;
                            hstart = oh * sH - pH;
                            wstart = ow * sW - pW;
                            dend = dstart + kDEff;
                            hend = hstart + kHEff;
                            wend = wstart + kWEff;

                            if(dstart < 0)
                                dstart += dD * ((-dstart + dD - 1) / dD);
                            if(hstart < 0)
                                hstart += dH * ((-hstart + dH - 1) / dH);
                            if(wstart < 0)
                                wstart += dW * ((-wstart + dW - 1) / dW);
                            if(dend > iD)
                                dend -= dD * ((dend-iD + dD - 1) / dD);
                            if(hend > iH)
                                hend -= dH * ((hend-iH + dH - 1) / dH);
                            if(wend > iW)
                                wend -= dW * ((wend-iW + dW - 1) / dW);

                            dstart *= iStride2;
                            dend   *= iStride2;
                            hstart *= iStride3;
                            hend   *= iStride3;
                            wstart *= iStride4;
                            wend   *= iStride4;

                            sum = -DataTypeUtils::max<T>();

                            for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) {
                                        T val = pIn[kd + kh + kw];
                                        if (val > sum)
                                            sum = val;
                                    }
                            out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum;
                        }
                    }
                }
            }
        }
    }
    /*************************************************************************/
    else if(poolingMode == 1) {   // avg
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int od = 0; od < oD; ++od) {
                    for(int oh = 0; oh < oH; ++oh) {
                        for(int ow = 0; ow < oW; ++ow) {

                            pIn = in + b * iStride0 + c * iStride1;

                            dstart = od * sD - pD;
                            hstart = oh * sH - pH;
                            wstart = ow * sW - pW;
                            dend = dstart + kDEff;
                            hend = hstart + kHEff;
                            wend = wstart + kWEff;

                            if(dstart < 0)
                                dstart += dD * ((-dstart + dD - 1) / dD);
                            if(hstart < 0)
                                hstart += dH * ((-hstart + dH - 1) / dH);
                            if(wstart < 0)
                                wstart += dW * ((-wstart + dW - 1) / dW);
                            if(dend > iD)
                                dend -= dD * ((dend-iD + dD - 1) / dD);
                            if(hend > iH)
                                hend -= dH * ((hend-iH + dH - 1) / dH);
                            if(wend > iW)
                                wend -= dW * ((wend-iW + dW - 1) / dW);

                            dstart *= iStride2;
                            dend   *= iStride2;
                            hstart *= iStride3;
                            hend   *= iStride3;
                            wstart *= iStride4;
                            wend   *= iStride4;

                            sum = static_cast<T>(0.);

                            for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep4)
                                        sum += pIn[kd + kh + kw];

                            if (extraParam0 == 0)         // Exclude padding
                                sum /= nd4j::math::nd4j_ceil<double,T>(static_cast<double>(dend-dstart) / static_cast<double>(iStep2)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend-hstart) / static_cast<double>(iStep3)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend-wstart) / static_cast<double>(iStep4));   // Accounts for dilation
                            else if (extraParam0 == 1)    // Include padding
                                sum /= kProd;

                            out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum;
                        }
                    }
                }
            }
        }
    }
    /*************************************************************************/
    else if(poolingMode == 2) {   // pnorm
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int od = 0; od < oD; ++od) {
                    for(int oh = 0; oh < oH; ++oh) {
                        for(int ow = 0; ow < oW; ++ow) {

                            pIn = in + b * iStride0 + c * iStride1;

                            dstart = od * sD - pD;
                            hstart = oh * sH - pH;
                            wstart = ow * sW - pW;
                            dend = dstart + kDEff;
                            hend = hstart + kHEff;
                            wend = wstart + kWEff;

                            if(dstart < 0)
                                dstart += dD * ((-dstart + dD - 1) / dD);
                            if(hstart < 0)
                                hstart += dH * ((-hstart + dH - 1) / dH);
                            if(wstart < 0)
                                wstart += dW * ((-wstart + dW - 1) / dW);
                            if(dend > iD)
                                dend -= dD * ((dend-iD + dD - 1) / dD);
                            if(hend > iH)
                                hend -= dH * ((hend-iH + dH - 1) / dH);
                            if(wend > iW)
                                wend -= dW * ((wend-iW + dW - 1) / dW);

                            dstart *= iStride2;
                            dend   *= iStride2;
                            hstart *= iStride3;
                            hend   *= iStride3;
                            wstart *= iStride4;
                            wend   *= iStride4;

                            sum = static_cast<T>(0.);

                            for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep4)
                                        sum += nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0);

                            sum = nd4j::math::nd4j_pow<T,T,T>(sum, (T) 1.f / extraParam0);

                            out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum;
                        }
                    }
                }
            }
        }
    }
    else {
        nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
        throw "";
    }
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void pooling2dBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
    // input [bS, iC, iH, iW]
    // gradI [bS, iC, iH, iW] -> gradI is output in this function
    // gradO [bS, iC, oH, oW]

    // initial zeroing of gradI
    gradI.nullify();

    T* in = const_cast<NDArray&>(input).bufferAsT<T>();
    T* gO = const_cast<NDArray&>(gradO).bufferAsT<T>();
    T* gI = gradI.bufferAsT<T>();

    const int kHEff = kH + (kH-1)*(dH-1);
    const int kWEff = kW + (kW-1)*(dW-1);

    const int bS = gradI.sizeAt(0);
    const int iC = gradI.sizeAt(1);
    const int iH = gradI.sizeAt(2);
    const int iW = gradI.sizeAt(3);
    const int oC = gradO.sizeAt(1);
    const int oH = gradO.sizeAt(2);
    const int oW = gradO.sizeAt(3);

#ifdef HAVE_MKLDNN
    if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
        if (streams.empty()) {
            streams.push_back(MKLDNNStream("pooling2d_bp"));
        }

        if (streams[0].checkAndReset({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0})) {
            mkldnn_memory_desc_t empty;
            mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
            mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
            mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
            mkldnn::algorithm algorithm;

            ConvolutionUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                    bS, iC, iH, iW, oC, oH, oW, &input, &gradI, &gradO, algorithm,
                    &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r);

            // input is sometimes null, so we can't rely on pool_src_md being valid
            auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
                    const_cast<NDArray&>(input).buffer() != nullptr ? pool_src_md : pool_diff_src_md,
                    pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);

            auto engine = streams[0].getEngine();
            auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);

            auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);

            auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
            auto userB_src_memory = mkldnn::memory({user_src_md, engine}, gradI.buffer());
            auto userB_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray&>(gradO).buffer());

            auto poolB_src_memory = userB_src_memory;
            streams[0].addMemory(userB_src_memory);
            if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
                    != userB_src_memory.get_primitive_desc()) {
                poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_primitive_desc());
                streams[0].addMemory(poolB_src_memory);
            }

            auto poolB_dst_memory = userB_dst_memory;
            streams[0].addMemory(userB_dst_memory);
            if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_dst_primitive_desc())
                    != userB_dst_memory.get_primitive_desc()) {
                poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_primitive_desc());
                streams[0].addMemory(poolB_dst_memory);
                streams[0].addOperation(reorder(userB_dst_memory, poolB_dst_memory));
            }

            if (algorithm == mkldnn::pooling_max) {
                auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());

                auto pool_src_memory = user_src_memory;
                streams[0].addMemory(user_src_memory);
                if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
                        != user_src_memory.get_primitive_desc()) {
                    pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
                    streams[0].addMemory(pool_src_memory);
                    streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
                }

                auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
                streams[0].addMemory(pool_dst_memory);

                auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_primitive_desc());
                streams[0].addMemory(pool_workspace_memory);

                streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory, pool_workspace_memory));
                streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, pool_workspace_memory, poolB_src_memory));
            } else {
                streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, poolB_src_memory));
            }

            if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
                    != userB_src_memory.get_primitive_desc()) {
                streams[0].addOperation(reorder(poolB_src_memory, userB_src_memory));
            }
        }

        streams[0].submitAndWait();
        return;
    }
#endif
    nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0);

    const Nd4jLong iStride0  = input.stridesOf()[0];
    const Nd4jLong iStride1  = input.stridesOf()[1];
    const Nd4jLong iStride2  = input.stridesOf()[2];
    const Nd4jLong iStride3  = input.stridesOf()[3];
    const Nd4jLong gIStride0 = gradI.stridesOf()[0];
    const Nd4jLong gIStride1 = gradI.stridesOf()[1];
    const Nd4jLong gIStride2 = gradI.stridesOf()[2];
    const Nd4jLong gIStride3 = gradI.stridesOf()[3];
    const Nd4jLong oStride0  = gradO.stridesOf()[0];
    const Nd4jLong oStride1  = gradO.stridesOf()[1];
    const Nd4jLong oStride2  = gradO.stridesOf()[2];
    const Nd4jLong oStride3  = gradO.stridesOf()[3];
    const Nd4jLong iStep2  = dH*iStride2;
    const Nd4jLong iStep3  = dW*iStride3;
    const Nd4jLong gIStep2 = dH*gIStride2;
    const Nd4jLong gIStep3 = dW*gIStride3;
    const int kProd = kH*kW;

    const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3;

    Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW;
    T sum, valO, *pIn, *pgI;

    if(poolingMode == 0) {        // max
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, sum, hstart, wstart, hend, wend, maxKH, maxKW))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {

                        pIn = in + b * iStride0 + c * iStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * ((-hstart + dH - 1) / dH);   // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(-hstart) / static_cast<T>(dH));
                        if(wstart < 0)
                            wstart += dW * ((-wstart + dW - 1) / dW);   // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(-wstart) / static_cast<T>(dW));
                        if(hend > iH)
                            hend -= dH * ((hend-iH + dH - 1) / dH);     // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(hend-iH) / static_cast<T>(dH));
                        if(wend > iW)
                            wend -= dW * ((wend-iW + dW - 1) / dW);     // (Nd4jLong)nd4j::math::nd4j_ceil<T,T>(static_cast<T>(wend-iW) / static_cast<T>(dW));

                        sum  = -DataTypeUtils::max<T>();
                        valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3];

                        if(sameStrides) {

                            hstart *= iStride2;
                            hend   *= iStride2;
                            wstart *= iStride3;
                            wend   *= iStride3;

                            // we set these to default values
                            maxKH = hstart;
                            maxKW = wstart;

                            for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) {
                                    T valIn = pIn[kh + kw];
                                    if (valIn > sum) {
                                        sum = valIn;
                                        maxKH = kh;
                                        maxKW = kw;
                                    }
                                }
                            gI[pIn - in + maxKH + maxKW] += valO;
                        }
                        else {

                            // we set these to default values
                            maxKH = hstart;
                            maxKW = wstart;

                            for (Nd4jLong kh = hstart; kh < hend; kh += dH)
                                for (Nd4jLong kw = wstart; kw < wend; kw += dW) {
                                    T valIn = pIn[kh * iStride2 + kw * iStride3];
                                    if (valIn > sum) {
                                        sum = valIn;
                                        maxKH = kh;
                                        maxKW = kw;
                                    }
                                }
                            gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + maxKW * gIStride3] += valO;
                        }
                    }
                }
            }
        }
    }
    /*************************************************************************/
    else if(poolingMode == 1) {   // avg
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pgI, valO, hstart, wstart, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {

                        pgI = gI + b * gIStride0 + c * gIStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * ((-hstart + dH - 1) / dH);
                        if(wstart < 0)
                            wstart += dW * ((-wstart + dW - 1) / dW);
                        if(hend > iH)
                            hend -= dH * ((hend-iH + dH - 1) / dH);
                        if(wend > iW)
                            wend -= dW * ((wend-iW + dW - 1) / dW);

                        hstart *= gIStride2;
                        hend   *= gIStride2;
                        wstart *= gIStride3;
                        wend   *= gIStride3;

                        valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3];

                        if ((int) extraParam0 == 0)       // Exclude padding
                            valO /= static_cast<T>(nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend-hstart) / static_cast<double>(gIStep2))) * static_cast<T>(nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend-wstart) / static_cast<double>(gIStep3)));   // Accounts for dilation
                        else if ((int) extraParam0 == 1)  // Include padding
                            valO /= kProd;

                        for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2)
                            for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3)
                                pgI[kh + kw] += valO;
                    }
                }
            }
        }
    }
    /*************************************************************************/
    else if(poolingMode == 2) {   // pnorm
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, pgI, sum, hstart, wstart, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {

                        pIn = in + b * iStride0 + c * iStride1;
                        pgI = sameStrides ? gI + (pIn - in) : gI + b * gIStride0 + c * gIStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * ((-hstart + dH - 1) / dH);
                        if(wstart < 0)
                            wstart += dW * ((-wstart + dW - 1) / dW);
                        if(hend > iH)
                            hend -= dH * ((hend-iH + dH - 1) / dH);
                        if(wend > iW)
                            wend -= dW * ((wend-iW + dW - 1) / dW);

                        sum  = static_cast<T>(0.f);
                        valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3];

                        if(sameStrides) {

                            hstart *= iStride2;
                            hend   *= iStride2;
                            wstart *= iStride3;
                            wend   *= iStride3;

                            for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                    sum += nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0);

                            valO *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1. - extraParam0) / extraParam0);

                            for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                    pgI[kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(pIn[kh + kw]);
                        }
                        else {

                            for (Nd4jLong kh = hstart; kh < hend; kh += dH)
                                for (Nd4jLong kw = wstart; kw < wend; kw += dW)
                                    sum += nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh * iStride2 + kw * iStride3]), extraParam0);

                            valO *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1. - extraParam0) / extraParam0);

                            for (Nd4jLong kh = hstart; kh < hend; kh += dH) {
                                for (Nd4jLong kw = wstart; kw < wend; kw += dW) {
                                    const auto inVal = pIn[kh * iStride2 + kw * iStride3];
                                    pgI[kh * gIStride2 + kw * gIStride3] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(inVal), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(inVal);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    else {
        nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
        throw "";
    }
}
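// Pnorm backprop used above, written out: with p = extraParam0 and S = sum_i |x_i|^p over the window, the forward
// output is y = S^(1/p), and dL/dx_i = dL/dy * S^((1-p)/p) * |x_i|^(p-1) * sign(x_i). The kernel first folds dL/dy
// (valO) together with S^((1-p)/p), then distributes valO * |x_i|^(p-1) * sign(x_i) over the window, which is what
// the two nested loops in each branch compute.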
//////////////////////////////////////////////////////////////////////////
|
|
template <typename T>
|
|
static void pooling3dBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
|
// input [bS, iC, iD, iH, iW]
|
|
// gradI [bS, iC, iD, iH, iW] -> gradI is output in this function
|
|
// gradO [bS, iC, oD, oH, oW]
|
|
|
|
// initial zeroing of gradI
|
|
gradI.nullify();
|
|
|
|
T* in = const_cast<NDArray&>(input).bufferAsT<T>();
|
|
T* gO = const_cast<NDArray&>(gradO).bufferAsT<T>();
|
|
T* gI = gradI.bufferAsT<T>();
|
|
|
|
    const int kDEff = kD + (kD-1)*(dD-1);
    const int kHEff = kH + (kH-1)*(dH-1);
    const int kWEff = kW + (kW-1)*(dW-1);

    const int bS = gradI.sizeAt(0);
    const int iC = gradI.sizeAt(1);
    const int iD = gradI.sizeAt(2);
    const int iH = gradI.sizeAt(3);
    const int iW = gradI.sizeAt(4);
    const int oC = gradO.sizeAt(1);
    const int oD = gradO.sizeAt(2);
    const int oH = gradO.sizeAt(3);
    const int oW = gradO.sizeAt(4);

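    // Fast path: for max (0) and avg (1) pooling, delegate to MKL-DNN when the backend is available
    // and supports the current data type; pnorm falls through to the plain C++ implementation below.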
#ifdef HAVE_MKLDNN
    if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
        if (streams.empty()) {
            streams.push_back(MKLDNNStream("pooling3d_bp"));
        }

        if (streams[0].checkAndReset({&input, &gradO}, {&gradI}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0})) {
            mkldnn_memory_desc_t empty;
            mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
            mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
            mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
            mkldnn::algorithm algorithm;

            ConvolutionUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
                    bS, iC, iD, iH, iW, oC, oD, oH, oW, &input, &gradI, &gradO, algorithm,
                    &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r);

            // input is sometimes null, so we can't rely on pool_src_md being valid
            if (const_cast<NDArray&>(input).buffer() == nullptr) {
                pool_src_md = pool_diff_src_md;
                user_src_md = user_diff_src_md;
            }

            auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md,
                    pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);

            auto engine = streams[0].getEngine();
            auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);

            auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
                    pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);

            auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
            auto userB_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI.buffer());
            auto userB_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray&>(gradO).buffer());

            auto poolB_src_memory = userB_src_memory;
            streams[0].addMemory(userB_src_memory);
            if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
                    != userB_src_memory.get_primitive_desc()) {
                poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_primitive_desc());
                streams[0].addMemory(poolB_src_memory);
            }

            auto poolB_dst_memory = userB_dst_memory;
            streams[0].addMemory(userB_dst_memory);
            if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_dst_primitive_desc())
                    != userB_dst_memory.get_primitive_desc()) {
                poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_primitive_desc());
                streams[0].addMemory(poolB_dst_memory);
                streams[0].addOperation(reorder(userB_dst_memory, poolB_dst_memory));
            }

            if (algorithm == mkldnn::pooling_max) {
                auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());

                auto pool_src_memory = user_src_memory;
                streams[0].addMemory(user_src_memory);
                if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
                        != user_src_memory.get_primitive_desc()) {
                    pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
                    streams[0].addMemory(pool_src_memory);
                    streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
                }

                auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
                streams[0].addMemory(pool_dst_memory);

                auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_primitive_desc());
                streams[0].addMemory(pool_workspace_memory);

                streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory, pool_workspace_memory));
                streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, pool_workspace_memory, poolB_src_memory));
            } else {
                streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, poolB_src_memory));
            }

            if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
                    != userB_src_memory.get_primitive_desc()) {
                streams[0].addOperation(reorder(poolB_src_memory, userB_src_memory));
            }
        }

        streams[0].submitAndWait();
        return;
    }
#endif
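
    // Plain C++ fallback: iterate over every output element and scatter its gradient contribution
    // into gradI using the raw strides gathered below (NCDHW layout).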
    nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0);

    const Nd4jLong iStride0  = input.stridesOf()[0];
    const Nd4jLong iStride1  = input.stridesOf()[1];
    const Nd4jLong iStride2  = input.stridesOf()[2];
    const Nd4jLong iStride3  = input.stridesOf()[3];
    const Nd4jLong iStride4  = input.stridesOf()[4];
    const Nd4jLong gIStride0 = gradI.stridesOf()[0];
    const Nd4jLong gIStride1 = gradI.stridesOf()[1];
    const Nd4jLong gIStride2 = gradI.stridesOf()[2];
    const Nd4jLong gIStride3 = gradI.stridesOf()[3];
    const Nd4jLong gIStride4 = gradI.stridesOf()[4];
    const Nd4jLong oStride0  = gradO.stridesOf()[0];
    const Nd4jLong oStride1  = gradO.stridesOf()[1];
    const Nd4jLong oStride2  = gradO.stridesOf()[2];
    const Nd4jLong oStride3  = gradO.stridesOf()[3];
    const Nd4jLong oStride4  = gradO.stridesOf()[4];
    const Nd4jLong iStep2    = dD*iStride2;
    const Nd4jLong iStep3    = dH*iStride3;
    const Nd4jLong iStep4    = dW*iStride4;
    const Nd4jLong gIStep2   = dD*gIStride2;
    const Nd4jLong gIStep3   = dH*gIStride3;
    const Nd4jLong gIStep4   = dW*gIStride4;
    const int      kProd     = kD*kH*kW;

    const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4;
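    // When input and gradI share the same strides, the window bounds can be pre-multiplied by the
    // strides once per output element, so the innermost loops only add precomputed offsets (kd + kh + kw).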

    Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
    T sum, valO, *pIn, *pgI;

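    // Max pooling backprop: each output gradient goes entirely to the input element that produced the
    // window maximum. Window starts/ends are first clamped to the valid input range; the ceil-division
    // idiom (e.g. dstart += dD * ((-dstart + dD - 1) / dD)) advances a negative start to the first
    // in-bounds dilated tap.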
    if(poolingMode == 0) {        // max
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, sum, dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int od = 0; od < oD; ++od) {
                    for(int oh = 0; oh < oH; ++oh) {
                        for(int ow = 0; ow < oW; ++ow) {

                            pIn = in + b * iStride0 + c * iStride1;

                            dstart = od * sD - pD;
                            hstart = oh * sH - pH;
                            wstart = ow * sW - pW;
                            dend   = dstart + kDEff;
                            hend   = hstart + kHEff;
                            wend   = wstart + kWEff;

                            if(dstart < 0)
                                dstart += dD * ((-dstart + dD - 1) / dD);
                            if(hstart < 0)
                                hstart += dH * ((-hstart + dH - 1) / dH);
                            if(wstart < 0)
                                wstart += dW * ((-wstart + dW - 1) / dW);
                            if(dend > iD)
                                dend -= dD * ((dend - iD + dD - 1) / dD);
                            if(hend > iH)
                                hend -= dH * ((hend - iH + dH - 1) / dH);
                            if(wend > iW)
                                wend -= dW * ((wend - iW + dW - 1) / dW);

                            sum  = -DataTypeUtils::max<T>();
                            valO = gO[b*oStride0 + c*oStride1 + od*oStride2 + oh*oStride3 + ow*oStride4];

                            if(sameStrides) {

                                dstart *= iStride2;
                                dend   *= iStride2;
                                hstart *= iStride3;
                                hend   *= iStride3;
                                wstart *= iStride4;
                                wend   *= iStride4;

                                maxKD = dstart;
                                maxKH = hstart;
                                maxKW = wstart;

                                for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                    for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                        for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) {
                                            T valIn = pIn[kd + kh + kw];
                                            if (valIn > sum) {
                                                sum = valIn;
                                                maxKD = kd;
                                                maxKH = kh;
                                                maxKW = kw;
                                            }
                                        }
                                gI[pIn - in + maxKD + maxKH + maxKW] += valO;
                            }
                            else {

                                // we set these to default values
                                maxKH = hstart;
                                maxKW = wstart;
                                maxKD = dstart;

                                for (Nd4jLong kd = dstart; kd < dend; kd += dD)
                                    for (Nd4jLong kh = hstart; kh < hend; kh += dH)
                                        for (Nd4jLong kw = wstart; kw < wend; kw += dW) {
                                            T valIn = pIn[kd * iStride2 + kh * iStride3 + kw * iStride4];
                                            if (valIn > sum) {
                                                sum = valIn;
                                                maxKD = kd;
                                                maxKH = kh;
                                                maxKW = kw;
                                            }
                                        }
                                gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + maxKH * gIStride3 + maxKW * gIStride4] += valO;
                            }
                        }
                    }
                }
            }
        }
    }
/*************************************************************************/
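    // Avg pooling backprop: every position in the window receives an equal share of the output gradient.
    // With extraParam0 == 0 the divisor counts only the in-bounds (non-padded) taps; with extraParam0 == 1
    // it is the full kernel volume kD*kH*kW.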
    else if(poolingMode == 1) {   // avg
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pgI, valO, dstart, hstart, wstart, dend, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int od = 0; od < oD; ++od) {
                    for(int oh = 0; oh < oH; ++oh) {
                        for(int ow = 0; ow < oW; ++ow) {

                            pgI = gI + b * gIStride0 + c * gIStride1;

                            dstart = od * sD - pD;
                            hstart = oh * sH - pH;
                            wstart = ow * sW - pW;
                            dend   = dstart + kDEff;
                            hend   = hstart + kHEff;
                            wend   = wstart + kWEff;

                            if(dstart < 0)
                                dstart += dD * ((-dstart + dD - 1) / dD);
                            if(hstart < 0)
                                hstart += dH * ((-hstart + dH - 1) / dH);
                            if(wstart < 0)
                                wstart += dW * ((-wstart + dW - 1) / dW);
                            if(dend > iD)
                                dend -= dD * ((dend - iD + dD - 1) / dD);
                            if(hend > iH)
                                hend -= dH * ((hend - iH + dH - 1) / dH);
                            if(wend > iW)
                                wend -= dW * ((wend - iW + dW - 1) / dW);

                            dstart *= gIStride2;
                            dend   *= gIStride2;
                            hstart *= gIStride3;
                            hend   *= gIStride3;
                            wstart *= gIStride4;
                            wend   *= gIStride4;

                            valO = gO[b*oStride0 + c*oStride1 + od*oStride2 + oh*oStride3 + ow*oStride4];

                            if (extraParam0 == 0)         // Exclude padding
                                valO /= nd4j::math::nd4j_ceil<double,T>(static_cast<double>(dend - dstart) / static_cast<double>(gIStep2)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend - hstart) / static_cast<double>(gIStep3)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend - wstart) / static_cast<double>(gIStep4));   // Accounts for dilation
                            else if (extraParam0 == 1)    // Include padding
                                valO /= kProd;

                            for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2)
                                for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4)
                                        pgI[kd + kh + kw] += valO;
                        }
                    }
                }
            }
        }
    }
/*************************************************************************/
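    // Pnorm pooling backprop: with p = extraParam0, the gradient of (sum |x|^p)^(1/p) w.r.t. each x is
    // gradO * sum^((1-p)/p) * |x|^(p-1) * sgn(x), which is what the loops below accumulate.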
    else if(poolingMode == 2) {   // pnorm
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, pgI, valO, sum, dstart, hstart, wstart, dend, hend, wend))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {
                for(int od = 0; od < oD; ++od) {
                    for(int oh = 0; oh < oH; ++oh) {
                        for(int ow = 0; ow < oW; ++ow) {

                            pIn = in + b * iStride0 + c * iStride1;
                            pgI = gI + (pIn - in);

                            dstart = od * sD - pD;
                            hstart = oh * sH - pH;
                            wstart = ow * sW - pW;
                            dend   = dstart + kDEff;
                            hend   = hstart + kHEff;
                            wend   = wstart + kWEff;

                            if(dstart < 0)
                                dstart += dD * ((-dstart + dD - 1) / dD);
                            if(hstart < 0)
                                hstart += dH * ((-hstart + dH - 1) / dH);
                            if(wstart < 0)
                                wstart += dW * ((-wstart + dW - 1) / dW);
                            if(dend > iD)
                                dend -= dD * ((dend - iD + dD - 1) / dD);
                            if(hend > iH)
                                hend -= dH * ((hend - iH + dH - 1) / dH);
                            if(wend > iW)
                                wend -= dW * ((wend - iW + dW - 1) / dW);

                            sum  = static_cast<T>(0.);
                            valO = gO[b*oStride0 + c*oStride1 + od*oStride2 + oh*oStride3 + ow*oStride4];

                            if(sameStrides) {

                                dstart *= iStride2;
                                dend   *= iStride2;
                                hstart *= iStride3;
                                hend   *= iStride3;
                                wstart *= iStride4;
                                wend   *= iStride4;

                                for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                    for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                        for (Nd4jLong kw = wstart; kw < wend; kw += iStep4)
                                            sum += nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0);

                                valO *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0);

                                for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                    for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                        for (Nd4jLong kw = wstart; kw < wend; kw += iStep4)
                                            pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0 - (T)1.f) * nd4j::math::nd4j_sgn<T,T>(pIn[kd + kh + kw]);
                            }
                            else {

                                for (Nd4jLong kd = dstart; kd < dend; kd += dD)
                                    for (Nd4jLong kh = hstart; kh < hend; kh += dH)
                                        for (Nd4jLong kw = wstart; kw < wend; kw += dW)
                                            sum += nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]), extraParam0);

                                valO *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0);

                                for (Nd4jLong kd = dstart; kd < dend; kd += dD)
                                    for (Nd4jLong kh = hstart; kh < hend; kh += dH)
                                        for (Nd4jLong kw = wstart; kw < wend; kw += dW) {
                                            const auto inVal = pIn[kd * iStride2 + kh * iStride3 + kw * iStride4];
                                            pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(inVal), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(inVal);
                                        }
                            }
                        }
                    }
                }
            }
        }
    }
    else {
        nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
        throw "";
    }
}

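// The dispatch wrappers below select the concrete template instantiation (conv2d_, pooling3dBP_, etc.)
// at runtime from the NDArray's data type via the BUILD_SINGLE_SELECTOR / BUILD_SINGLE_SELECTOR_TWICE
// macros, restricted to FLOAT_TYPES.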
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
    BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
}
void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
    BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) {
    BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) {
    BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
    BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
    BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES);
}

void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
    BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
    BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
    BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
    BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}

}
}