cavis/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp

1853 lines
92 KiB
C++
Raw Normal View History

/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
// implementation of operation for LSTM cell with peep hole connections:
// http://www.bioinf.jku.at/publications/older/2604.pdf
// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
// and
// https://research.google.com/pubs/archive/43905.pdf
// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
#include <ops/declarable/helpers/lstmLayer.h>
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
#include <execution/Threads.h>
#include <ops/declarable/helpers/activations.h>
#include <helpers/ShapeUtils.h>
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
#include <helpers/MmulHelper.h>
// #include <VariableSpace.h>
// #include <ops/declarable/CustomOperations.h>
// #include<ops/declarable/helpers/transforms.h>
// #include <ops/declarable/helpers/legacy_helpers.h>
// #include <array/NDArrayList.h>
// #include <iterator>
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
namespace sd {
namespace ops {
namespace helpers {
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
//////////////////////////////////////////////////////////////////////////
// Applies the activation selected by opId to every element of x and writes the result into z.
// opId : 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu,
//        6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
// alpha, beta : extra activation parameters, forwarded only by the cases that take them
// throws std::invalid_argument when opId is not one of the ids listed above
static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {

    // the apply* calls below are issued through a non-const alias of the input array
    auto& input = const_cast<NDArray&>(x);

    switch (opId) {

        case 0:     // tanh
            input.applyTransform(transform::Tanh, z);
            return;

        case 1:     // relu, scalar argument is fixed to 0
            input.applyScalar<float>(scalar::RELU, 0, z);
            return;

        case 2:     // sigmoid
            input.applyTransform(transform::Sigmoid, z);
            return;

        case 3: {   // affine, alpha and beta are passed as extra arguments
            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
            input.applyTransform(transform::Affine, z, &args);
            return;
        }

        case 4:     // leaky relu, alpha is the scalar argument
            input.applyScalar<float>(scalar::LeakyRELU, alpha, z);
            return;

        case 5:     // thresholded relu, delegated to the dedicated helper with alpha
            thresholdRelu(x.getContext(), x, alpha, z);
            return;

        case 6: {   // scaled tanh, alpha and beta are passed as extra arguments
            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
            input.applyTransform(transform::ScaledTanh, z, &args);
            return;
        }

        case 7:     // hard sigmoid
            input.applyTransform(transform::HardSigmoid, z);
            return;

        case 8:     // ELU, alpha is the scalar argument
            input.applyScalar<float>(scalar::ELU, alpha, z);
            return;

        case 9:     // softsign
            input.applyTransform(transform::SoftSign, z);
            return;

        case 10:    // softplus
            input.applyTransform(transform::SoftPlus, z);
            return;

        default:
            throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
    }
}
//////////////////////////////////////////////////////////////////////////
// Writes into z the element-wise derivative of the activation selected by opId, evaluated at x.
// opId : 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu,
//        6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
// alpha, beta : extra activation parameters, used only by the cases that mention them
// throws std::invalid_argument when opId is not one of the ids listed above
static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {

    switch (opId) {
        case 0:
            (const_cast<NDArray&>(x)).applyTransform(transform::TanhDerivative, z);
            break;
        case 1:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELUDerivative, 0, z);
            break;
        case 2:
            (const_cast<NDArray&>(x)).applyTransform(transform::SigmoidDerivative, z);
            break;
        case 3: {
            // affine: every element of z is set to alpha
            z = alpha;
            break;
        }
        case 4:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELUDerivative, alpha, z);
            break;
        case 5:
            // thresholded relu: relu derivative with alpha passed as the scalar argument
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELUDerivative, alpha, z);
            break;
        case 6: {
            // scaled tanh: z = alpha * beta * (1 - tanh^2(beta * x)), tanh is evaluated once per element
            auto func = PRAGMA_THREADS_FOR {
                for(Nd4jLong i = start; i < stop; ++i) {
                    const auto th = sd::math::nd4j_tanh<float,float>(beta * x.e<float>(i));
                    z.p<float>(i, alpha * beta * (1.f - th * th));
                }
            };
            samediff::Threads::parallel_for(func, 0, x.lengthOf());
            break;
        }
        case 7:
            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoidDerivative, z);
            break;
        case 8:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELUDerivative, alpha, z);
            break;
        case 9:
            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSignDerivative, z);
            break;
        case 10: {
            // softplus: z = exp(x) / (1 + exp(x)), i.e. the logistic function.
            // The exponent is always taken of a non-positive number: evaluating exp(x) directly
            // overflows to inf for large positive x and the quotient inf/inf would produce NaN.
            auto func = PRAGMA_THREADS_FOR {
                for(Nd4jLong i = start; i < stop; ++i) {
                    const auto xi = x.e<float>(i);
                    if(xi >= 0.f) {
                        // equivalent form 1 / (1 + exp(-x)), exp(-x) is in (0, 1] here
                        z.p<float>(i, 1.f / (1.f + sd::math::nd4j_exp<float, float>(-xi)));
                    }
                    else {
                        // negative (or NaN) input: exp(x) cannot overflow, keep the direct form
                        const auto expX = sd::math::nd4j_exp<float, float>(xi);
                        z.p<float>(i, expX / (1.f + expX));
                    }
                }
            };
            samediff::Threads::parallel_for(func, 0, x.lengthOf());
            break;
        }
        default:
            throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
    }
}
//////////////////////////////////////////////////////////////////////////
// Zeroes the i-th element of z0, z1, z2 and z3 wherever the i-th element of c equals clipVal or -clipVal.
// Does nothing when clipVal == 0.
// FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal
static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) {

    if(clipVal == 0)
        return;

    auto func = PRAGMA_THREADS_FOR {

        for(Nd4jLong idx = start; idx < stop; ++idx) {

            const auto cellVal = c.e<float>(idx);

            // element is not sitting on either clipping boundary -> leave the four arrays untouched
            if(cellVal != clipVal && cellVal != -clipVal)
                continue;

            z0.p<float>(idx, 0.f);
            z1.p<float>(idx, 0.f);
            z2.p<float>(idx, 0.f);
            z3.p<float>(idx, 0.f);
        }
    };

    samediff::Threads::parallel_for(func, 0, c.lengthOf());
}
//////////////////////////////////////////////////////////////////////////
// Returns the sub-array of arr restricted to the time interval [t1,t2] and batch interval [b1,b2]
// (as interpreted by NDArray::operator()), the remaining dimension is passed as {0,0}.
// The position of the time and batch intervals in the index list depends on dataFormat:
//   0 or 3         -> TNS: [sL, bS, nIn]
//   1              -> NTS: [bS, sL, nIn]
//   anything else  -> NST: [bS, nIn, sL]
static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {

    switch (dataFormat) {

        case 0:
        case 3:
            return arr({t1,t2, b1,b2, 0,0});    // time first, then batch

        case 1:
            return arr({b1,b2, t1,t2, 0,0});    // batch first, then time

        default:
            return arr({b1,b2, 0,0, t1,t2});    // batch first, time last
    }
}
//////////////////////////////////////////////////////////////////////////
// Maps a (time step t, batch item b) pair onto a single flat index.
// dataFormat 0 or 3 (TNS, shape [sL, bS, nIn])                      -> t * bS + b
// any other dataFormat (NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]) -> b * sL + t
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {

    const bool timeIsLeadingDim = (dataFormat == 0) || (dataFormat == 3);

    return timeIsLeadingDim ? (t * bS + b) : (b * sL + t);
}
//////////////////////////////////////////////////////////////////////////
// Builds the Jacobian of a matrix product with respect to the weights y:
// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} -> if x.rankOf() == 2
// x{K}   x y{K,N} = z{N},   dzdy{K,N,N}   -> if x.rankOf() == 1
// nOut plays the role of N. The result has the ordering, data type and context of x.
// Elements written:
//   rank 1: dzdy(i0,i1,i2)    = x(i0)    if i1 == i2, otherwise 0
//   rank 2: dzdy(i0,i1,i2,i3) = x(i2,i0) if i1 == i3, otherwise 0
static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) {

    const bool xIsVector = x.rankOf() == 1;

    std::vector<Nd4jLong> outShape;
    if(xIsVector)
        outShape = std::vector<Nd4jLong>({x.sizeAt(0), nOut, nOut});
    else
        outShape = std::vector<Nd4jLong>({x.sizeAt(1), nOut, x.sizeAt(0), nOut});

    NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext());

    if(xIsVector) {

        auto fillRank3 = PRAGMA_THREADS_FOR_3D {

            for (auto k = start_x; k < stop_x; ++k) {
                for (auto n = start_y; n < stop_y; ++n) {
                    for (auto col = start_z; col < stop_z; ++col) {

                        // only entries whose two nOut-sized indices coincide are non-zero
                        if(n != col)
                            dzdy.p<double>(k,n,col, 0);
                        else
                            dzdy.p<double>(k,n,col, x.e<double>(k));
                    }
                }
            }
        };

        samediff::Threads::parallel_for(fillRank3, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1);
    }
    else {

        // the last dimension is not split between threads, it is walked sequentially inside the kernel
        const Nd4jLong lastDim = dzdy.sizeAt(3);

        auto fillRank4 = PRAGMA_THREADS_FOR_3D {

            for (auto k = start_x; k < stop_x; ++k) {
                for (auto n = start_y; n < stop_y; ++n) {
                    for (auto row = start_z; row < stop_z; ++row) {
                        for (Nd4jLong col = 0; col < lastDim; ++col) {

                            // only entries whose two nOut-sized indices coincide are non-zero
                            if(n != col)
                                dzdy.p<double>(k,n,row,col, 0);
                            else
                                dzdy.p<double>(k,n,row,col, x.e<double>(row,k));
                        }
                    }
                }
            }
        };

        samediff::Threads::parallel_for(fillRank4, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1);
    }

    return dzdy;
}
//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
NDArray* h, NDArray* c) {
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
// * -> means element-wise multiplication
// ^ -> means matrix multiplication
/************************ THIS IS NOT OPTIMIZED CODE ***********************************/
/** the objective is to provide math-readable code **/
// equations (no peephole connections)
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
// it = σ(Wxi ^ xt + Wri ^ ht-1 + bi)
// ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf)
// c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc)
// ct = ft * ct-1 + it * c't
// ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo)
// ht = ot * tanh(ct)
// equations (peephole connections are present)
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
// it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi)
// ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf)
// c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc)
// ct = ft * ct-1 + it * c't
// ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo)
// ht = ot * tanh(ct)
// IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
// params[0] - dataFormat, ignore
// params[1] - directionMode, ignore
// params[2] - cell clipping value, if it = 0 then do not apply clipping
// params[3] - activation ID for input (i), forget (f) and output (o) gates
// params[4] - alpha value for gates activation
// params[5] - beta value for gates activation
// params[6] - activation ID for cell state (c)
// params[7] - alpha value for cell state activation
// params[8] - beta value for cell state activation
// params[9] - activation ID for output (h)
// params[10] - alpha value for output activation
// params[11] - beta value for output activation
// INPUTS:
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
// hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// OUTPUTS:
// h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
// c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
//or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
// add biases if they are given
if(b != nullptr)
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut])
auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut])
auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut])
auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
// peephole connections for input and forget gates
if(Wp != nullptr) {
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
}
applyActivation(zi, params[3], params[4], params[5], zi); // inplace
applyActivation(zf, params[3], params[4], params[5], zf); // inplace
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
applyActivation(zg, params[6], params[7], params[8], zg); // inplace
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
// if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
if(params[2] != 0)
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: 
Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. 
const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
c->applyScalar(scalar::LstmClip, params[2], *c);
// peephole connections for output gate
if(Wp != nullptr)
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
applyActivation(zo, params[3], params[4], params[5], zo);
applyActivation(*c, params[9], params[10], params[11], *h);
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
*h *= zo; // [bS, nOut] * [bS, nOut](or[nOut])
}
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
//////////////////////////////////////////////////////////////////////////
// this auxiliary ff should be running before backprop
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
NDArray* z, NDArray* a, NDArray* h, NDArray* c) {
// z - zi, zf, zg, zo
// a - i, f, g, o
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
//or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
// add biases if they are given
if(b != nullptr)
*z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]
auto zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut])
auto zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut])
auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut])
auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
auto i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut])
auto f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut])
auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut])
auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
// peephole connections for input and forget gates
if(Wp != nullptr) {
zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
}
applyActivation(zi, params[3], params[4], params[5], i);
applyActivation(zf, params[3], params[4], params[5], f);
applyActivation(zg, params[6], params[7], params[8], g);
c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
// if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
if(params[2] != 0)
c->applyScalar(scalar::LstmClip, params[2], *c);
// peephole connections for output gate
if(Wp != nullptr)
zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
applyActivation(zo, params[3], params[4], params[5], o);
applyActivation(*c, params[9], params[10], params[11], *h);
*h *= o; // [bS, nOut] * [bS, nOut](or[nOut])
}
//////////////////////////////////////////////////////////////////////////
void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const NDArray* dLdh, const NDArray* dLdc,
const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {
/************************ THIS IS NOT OPTIMIZED CODE ***********************************/
/** the objective is to provide math-readable code **/
// equations (no peephole connections)
// zi = x ^ Wxi + hI ^ Wri + bi
// zf = x ^ Wxf + hI ^ Wrf + bf
// zg = x ^ Wxg + hI ^ Wrg + bg
// zo = x ^ Wxo + hI ^ Wro + bo
// i = act(zi)
// f = act(zf)
// g = actC(zg)
// o = act(zo)
// c = clip(f * cI + i * g)
// h = o * actH(c)
// equations (peephole connections are present)
// zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi
// zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf
// zg = x ^ Wxg + hI ^ Wrg + bg
// zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo
// i = act(zi)
// f = act(zf)
// g = actC(zg)
// o = act(zo)
// c = clip(f * cI + i * g)
// h = o * actH(c)
// IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
// params[0] - dataFormat, ignore
// params[1] - directionMode, ignore
// params[2] - cell clipping value, if it = 0 then do not apply clipping
// params[3] - activation ID for input (i), forget (f) and output (o) gates
// params[4] - alpha value for gates activation
// params[5] - beta value for gates activation
// params[6] - activation ID for cell state (c)
// params[7] - alpha value for cell state activation
// params[8] - beta value for cell state activation
// params[9] - activation ID for output (h)
// params[10] - alpha value for output activation
// params[11] - beta value for output activation
// INPUTS:
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr
// dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr
// z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
// a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
// c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut]
// OUTPUTS:
// dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
// dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut]
// dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut]
// dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut]
// dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut]
// !!! dimension 4*nOut implies order i, f, g, o
// !!! dimension 3*nOut implies order i, f, o
// dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut]
// dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut]
// dLdhI += dLdh; [bS, nOut]
// dLdcI += dLdhI * dhdc; [bS, nOut]
// dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut])
// dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut])
// dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut])
// dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut])
// dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn]
// dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut]
// dLdcI = dLdcI*dcdcI, [bS, nOut]
// dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxg = xT^dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
const Nd4jLong nIn = x->sizeAt(-1);
NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut])
NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut])
NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut])
NDArray zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut])
NDArray i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut])
NDArray f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut])
NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut])
NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut])
NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut])
NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0,0, 0, nOut});
NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut, 2*nOut}) : dLdz({0,0, nOut, 2*nOut});
NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut});
NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut});
// dcdzi = dcdi*didzi, [bS, nOut](or[nOut])
activationDeriv(zi, params[3], params[4], params[5], dLdzi); // didzi, inplace
dLdzi *= g; // dcdi = g*clipDeriv
// dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut])
activationDeriv(zf, params[3], params[4], params[5], dLdzf); // dfdzf, inplace
dLdzf *= *cI; // dcdf = cI*clipDeriv
// dcdzg = dcdg*dgdzg, [bS, nOut](or[nOut])
activationDeriv(zg, params[6], params[7], params[8], dLdzg); // dgdzg, inplace
dLdzg *= i; // dcdf = i*clipDeriv
// dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut])
activationDeriv(zo, params[3], params[4], params[5], dLdzo);
NDArray temp = dLdzo.ulike();
applyActivation(*c, params[9], params[10], params[11], temp); // actH(c), inplace
dLdzo *= temp;
// dcdcI
NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut])
// take into account possible deposit from clipping derivative
clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI);
// dhdc
NDArray dhdc = c->ulike();
activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut]
dhdc *= o;
if(Wp) {
dhdc += dLdzo*(*Wp)({2*nOut, 3*nOut});
dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut}); // broadcast [bS, nOut] * nOut + ...
}
if(dLdh)
*dLdhI += *dLdh;
if(dLdc)
*dLdcI += *dLdc;
else
*dLdcI += *dLdhI * dhdc;
dLdzi *= *dLdcI; // [bS, nOut](or[nOut])
dLdzf *= *dLdcI; // [bS, nOut](or[nOut])
dLdzg *= *dLdcI; // [bS, nOut](or[nOut])
dLdzo *= *dLdhI; // [bS, nOut](or[nOut])
// dLdx
NDArray WxT = Wx->transpose();
MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] )
// dLdhI
NDArray WrT = Wr->transpose();
MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] )
// dLdcI
dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut])
if(x->rankOf() == 1) {
NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1]
NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1]
NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut]
// dLdWx
*dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut]
// dLdWr
*dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut]
}
else {
// dLdWx
*dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut]
// dLdWr
*dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut]
}
// dLdb
if(b && x->rankOf() == 1)
*dLdb += dLdz; // [4*nOut]
else if(b)
*dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut];
// dLdWp
if(Wp && x->rankOf() == 1) {
(*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut]
(*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut]
(*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut]
}
else if(Wp) {
NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext());
(std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut]
(*dLdWp)({0,nOut}) += temp;
(std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut]
(*dLdWp)({nOut,2*nOut}) += temp;
(std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut]
(*dLdWp)({2*nOut,3*nOut}) += temp;
}
}
//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
const bool forward,
NDArray* h, NDArray* hL, NDArray* cL) {
// INPUTS:
// x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL],
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// seqLen - [bS], optional, may be nullptr
// hI - initial output [bS, nOut], optional, may be nullptr
// cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// OUTPUTS:
// h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
// hL - output at last step [bS, nOut], optional, may be nullptr
// cL - cell state at last step [bS, nOut], optional, may be nullptr
// params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
// dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]
const int dataFormat = params[0];
const int directionMode = params[1];
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
const std::vector<Nd4jLong> shapeOut = {bS, nOut};
auto h0 = const_cast<NDArray*>(hI);
if(!hI) {
h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
h0->nullify();
}
auto c0 = const_cast<NDArray*>(cI);
if(!cI) {
c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
c0->nullify();
}
auto ct = cL;
if(!cL)
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
auto ht = hL;
if(!h && !hL)
ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
// create sets of required (depends on seqLen presence) sub-arrays
std::vector<int> dims;
ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);
if(!seqLen) {
dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: 
Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. 
const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn]
if(h)
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: 
Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. 
const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut]
}
else {
dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2}); // points on nIn/nOut axis
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: 
Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. 
const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn]
h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
if(h)
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: 
Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. 
const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut]
if(ht)
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: 
Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. 
const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
}
// loops
if(forward) {
if(!seqLen) {
if(!h) { // seqLen and h are absent
lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
for (Nd4jLong t = 1; t < sL; ++t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
}
else { // seqLen is absent and h is present
lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step
for (Nd4jLong t = 1; t < sL; ++t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps
if(hL)
hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr
}
}
else {
if(!h) { // seqLen is present and h is absent
for (Nd4jLong e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
for (int t = 1; t < limit; ++t) {
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
}
}
}
else { // seqLen and h are present
for (Nd4jLong e = 0; e < bS; ++e) {
int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
for (int t = 1; t < limit; ++t) {
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
indPrev = indCurr;
}
if(hL)
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
if(limit != sL)
tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
}
else { // backward
if(!seqLen) {
if(!h) { // seqLen and h are absent
lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
for (Nd4jLong t = sL - 2; t >= 0; --t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
}
else { // seqLen is absent and h is present
lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step
for (Nd4jLong t = sL - 2; t >= 0; --t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps
if(hL)
hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr
}
}
else if(directionMode == 1) { // only backward, no bidirectional mode
if(!h) { // h is absent and seqLen is present
for (Nd4jLong e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
for (Nd4jLong t = sL - 2; t >= sL - limit; --t) {
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
}
}
}
else { // seqLen and h are present
for (Nd4jLong e = 0; e < bS; ++e) {
int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
for (Nd4jLong t = sL - 2; t >= sL - limit; --t) {
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
indPrev = indCurr;
}
if(hL)
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
if(limit != sL)
tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
else { // backward in bidirectional mode
if(!h) { // h is absent and seqLen is present
for (Nd4jLong e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
for (int t = limit - 2; t >= 0; --t) {
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
}
}
}
else { // seqLen and h are present
for (Nd4jLong e = 0; e < bS; ++e) {
int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
for (int t = limit - 2; t >= 0; --t) {
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
indPrev = indCurr;
}
if(hL)
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
if(limit != sL)
tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
}
delete xSet;
delete hSet;
delete h0Set;
delete c0Set;
delete htSet;
delete ctSet;
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
if(!hI)
delete h0;
if(!cI)
delete c0;
if(!cL)
delete ct;
if(!h && !hL)
delete ht;
}
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
const std::vector<float>& params, const bool forward,
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) {
// INPUTS:
// x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL],
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// seqLen - [bS], optional, may be nullptr
// hI - initial output [bS, nOut], optional, may be nullptr
// cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
// dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr
// dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr
// OUTPUTS:
// dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL]
// dLdWx - gradient vs. input weights [nIn, 4*nOut]
// dLdWr - gradient vs. recurrent weights [nOut, 4*nOut]
// dLdb - gradient vs. biases [4*nOut], optional, may be nullptr
// dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr
// dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr
// dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr
// params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
// dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]
const int dataFormat = params[0];
const int directionMode = params[1];
const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const int nOut = Wx->sizeAt(-1) / 4;
auto dLdh0 = dLdhI;
if(!hI)
dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically
auto dLdc0 = dLdcI;
if(!cI)
dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically
NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext());
NDArray a = z.ulike();
NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext());
NDArray c = h.ulike();
// create sets of required (depends on seqLen presence) sub-arrays
std::vector<int> dims;
ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr),
*dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr);
if(!seqLen) {
dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut]
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn]
dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn]
hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut]
cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut]
zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut]
aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut]
if(dLdh)
dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut]
}
else {
dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2}); // points on nIn/nOut axis
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn]
dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn]
hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut]
cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut]
zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut]
aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut]
if(hI)
hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
if(cI)
cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
if(dLdh)
dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut]
if(!dLdh && dLdhL)
dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
if(!dLdh && !dLdhL)
dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
}
// loops
if(forward) {
if(!seqLen) { // seqLen is absent
if(hI)
h({0,1, 0,0, 0,0}).assign(hI);
else
h({0,1, 0,0, 0,0}).nullify();
if(cI)
c({0,1, 0,0, 0,0}).assign(cI);
else
c({0,1, 0,0, 0,0}).nullify();
// ff
for (int t = 0; t < sL; ++t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1));
// bp
for (int t = sL-1; t >= 0; --t) {
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr);
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc,
zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
}
}
else { // seqLen is present
for (int e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
continue;
}
if(hI)
h({0,1, e,e+1, 0,0}).assign(hISet->at(e));
else
h({0,1, e,e+1, 0,0}).nullify();
if(cI)
c({0,1, e,e+1, 0,0}).assign(cISet->at(e));
else
c({0,1, e,e+1, 0,0}).nullify();
// ff
for (int t = 0; t < limit; ++t)
lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params,
zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e));
// bp
for (int t = limit-1; t >= 0; --t) {
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? dLdcLSet->at(e) : nullptr);
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc,
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
}
if(limit != sL)
tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
else { // backward or bidirectional
if(!seqLen) { // backward or bidirectional, seqLen is absent
if(hI)
h({sL,sL+1, 0,0, 0,0}).assign(hI);
else
h({sL,sL+1, 0,0, 0,0}).nullify();
if(cI)
c({sL,sL+1, 0,0, 0,0}).assign(cI);
else
c({sL,sL+1, 0,0, 0,0}).nullify();
// ff
for (int t = sL-1; t >= 0; --t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t));
// bp
for (int t = 0; t < sL; ++t) {
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr);
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc,
zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
}
}
else if(directionMode == 1) { // backward, seqLen is present
for (int e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
continue;
}
if(hI)
h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e));
else
h({sL,sL+1, e,e+1, 0,0}).nullify();
if(cI)
c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e));
else
c({sL,sL+1, e,e+1, 0,0}).nullify();
// ff
for (int t = sL - 1; t >= sL-limit; --t)
lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params,
zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e));
// bp
for (int t = sL-limit; t < sL; ++t) {
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? dLdcLSet->at(e) : nullptr);
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc,
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
}
if(limit != sL)
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
else { // bidirectional mode, seqLen is present
for (int e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
continue;
}
if(hI)
h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e));
else
h({limit,limit+1, e,e+1, 0,0}).nullify();
if(cI)
c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e));
else
c({limit,limit+1, e,e+1, 0,0}).nullify();
// ff
for (int t = limit - 1; t >= 0; --t)
lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params,
zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e));
// bp
for (int t = 0; t < limit; ++t) {
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr);
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc,
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
}
if(limit != sL)
tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet;
delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet;
if(!hI)
delete dLdh0;
if(!cI)
delete dLdc0;
}
}
}
}
Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii <iuriish@yahoo.com> * - further working on bp for lstmLayer Signed-off-by: Yurii <iuriish@yahoo.com> * - minor change Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 4 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 5 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 6 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 7 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 8 Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 9 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii <iuriish@yahoo.com> * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on bp for lstmLayer 10 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii 
<iuriish@yahoo.com> * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii <iuriish@yahoo.com>
2020-04-13 12:21:51 +02:00
//////////////////////////////////////////////////////////////////////////
// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh,
// const std::vector<float>& params, const bool firstIter,
// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr,
// NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp,
// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {
// /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
// /** the objective is to provide math-readable code **/
// // equations (no peephole connections)
// // zi = x ^ Wxi + hI ^ Wri + bi
// // zf = x ^ Wxf + hI ^ Wrf + bf
// // zg = x ^ Wxg + hI ^ Wrg + bg
// // zo = x ^ Wxo + hI ^ Wro + bo
// // i = act(zi)
// // f = act(zf)
// // g = actC(zg)
// // o = act(zo)
// // c = clip(f * cI + i * g)
// // h = o * actH(c)
// // equations (peephole connections are present)
// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi
// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf
// // zg = x ^ Wxg + hI ^ Wrg + bg
// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo
// // i = act(zi)
// // f = act(zf)
// // g = actC(zg)
// // o = act(zo)
// // c = clip(f * cI + i * g)
// // h = o * actH(c)
// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
// // params[0] - dataFormat, ignore
// // params[1] - directionMode, ignore
// // params[2] - cell clipping value, if it = 0 then do not apply clipping
// // params[3] - activation ID for input (i), forget (f) and output (o) gates
// // params[4] - alpha value for gates activation
// // params[5] - beta value for gates activation
// // params[6] - activation ID for cell state (c)
// // params[7] - alpha value for cell state activation
// // params[8] - beta value for cell state activation
// // params[9] - activation ID for output (h)
// // params[10] - alpha value for output activation
// // params[11] - beta value for output activation
// // INPUTS:
// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
// // Wx - input weights [nIn, 4*nOut]
// // Wr - recurrent weights [nOut, 4*nOut]
// // b - biases [4*nOut], optional, may be nullptr
// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// // Wp - peephole weights [3*nOut], optional, may be nullptr
// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr
// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr
// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr
// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr
// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr
// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr
// // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr
// // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr
// // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr
// // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr
// // OUTPUTS:
// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
// // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut]
// // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut]
// // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut]
// // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut]
// // !!! dimension 4*nOut implies order i, f, g, o
// // !!! dimension 3*nOut implies order i, f, o
// // dcdzi = dcdi*didzi
// // dcdzf = dcdf*dfdzf
// // dcdzg = dcdg*dgdzg
// // dhdzo = dhdo*dodzo
// // dhdc = dhdc + (Wp ? dhdzo*dzodc : 0) [bS, nOut]
// // factor = dLdh*dhdc [bS, nOut]
// // iFactor = factor*dcdzi [bS, nOut]
// // fFactor = factor*dcdzf [bS, nOut]
// // eFactor = factor*dcdzg [bS, nOut]
// // oFactor = dLdh*dhdzo [bS, nOut]
// // tempC = dcdcI + (Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0);
// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT
// // tempO = dhdzo^WroT
// // dhIdcI = dhdc_from_previous_time_step
// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn]
// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut]
// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut]
// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut]
// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWrf=dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWrg=dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut]
// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceAlongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceAlongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceAlongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceAlongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceAlongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceAlongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceAlongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + tempC*dcIdbf).reduceAlongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceAlongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceAlongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceAlongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceAlongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceAlongFirstDim, dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceAlongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[nOut]
// const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr);
// if(Wp) {
// Wpi = new NDArray((*Wp)({0, nOut}));
// Wpf = new NDArray((*Wp)({nOut, 2*nOut}));
// Wpo = new NDArray((*Wp)({2*nOut, 3*nOut}));
// dhIdWpi = new NDArray((*dhIdWp)({0, nOut}));
// dhIdWpf = new NDArray((*dhIdWp)({nOut, 2*nOut}));
// dhIdWpo = new NDArray((*dhIdWp)({2*nOut, 3*nOut}));
// dcIdWpi = new NDArray((*dcIdWp)({0, nOut}));
// dcIdWpf = new NDArray((*dcIdWp)({nOut, 2*nOut}));
// dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut}));
// }
// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr);
// if(b) {
// dhIdbi = new NDArray((*dhIdb)({0, nOut}));
// dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut}));
// dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut}));
// dhIdbo = new NDArray((*dhIdb)({3*nOut, 4*nOut}));
// dcIdbi = new NDArray((*dcIdb)({0, nOut}));
// dcIdbf = new NDArray((*dcIdb)({nOut, 2*nOut}));
// dcIdbg = new NDArray((*dcIdb)({2*nOut, 3*nOut}));
// dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut}));
// }
// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn]
// NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn]
// NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn]
// NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn]
// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut]
// NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut]
// NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut]
// NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut]
// // ***** feed forward step ***** //
// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
// //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
// // add biases if they are given
// if(b)
// z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut])
// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut])
// auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut])
// auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut])
// auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut])
// // peephole connections for input and forget gates
// if(Wp) {
// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
// zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
// }
// NDArray i = zi.ulike(); // [bS, nOut]
// NDArray f = zf.ulike(); // [bS, nOut]
// NDArray g = zg.ulike(); // [bS, nOut]
// applyActivation(zi, params[3], params[4], params[5], i);
// applyActivation(zf, params[3], params[4], params[5], f);
// applyActivation(zg, params[6], params[7], params[8], g);
// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
// // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
// if(params[2] != 0)
// c.applyScalar(scalar::LstmClip, params[2], c);
// // peephole connections for output gate
// if(Wp)
// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
// NDArray o = zo.ulike(); // [bS, nOut](or[nOut])
// applyActivation(zo, params[3], params[4], params[5], o);
// // ***** back prop step ***** //
// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut])
// NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut])
// // dodzo
// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut])
// activationDeriv(zo, params[3], params[4], params[5], dodzo);
// // dhdzo = dhdo*dodzo = actH(c)*dodzo
// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut])
// applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c)
// hI->assign(o*dhdzo);
// dhdzo *= dodzo;
// // dcdzi = dcdi*didzi
// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut])
// activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi
// dcdzi *= g; // dcdi = g*clipDeriv
// // dcdzf = dcdf*dfdzf
// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut])
// activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf
// dcdzf *= *cI; // dcdf = cI*clipDeriv
// // dcdzg = dcdg*dgdzg
// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut])
// activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dgdzg
// dcdzg *= i; // dcdg = i*clipDeriv
// // dcdcI
// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut])
// // take into account possible deposit from clipping derivative
// clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI);
// // dzodc
// NDArray* dzodc = Wpo; // [nOut], should be [bS, nOut] actually, however it will be broadcast appropriately in later calculations (element-wise multiplication)
// // dzidcI
// NDArray* dzidcI = Wpi; // [nOut], should be [bS, nOut] actually, however it will be broadcast appropriately in later calculations (element-wise multiplication)
// // dzfdcI
// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, nOut] actually, however it will be broadcast appropriately in later calculations (element-wise multiplication)
// // dhdc
// NDArray dhdc = c.ulike();
// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut]
// dhdc *= o;
// if(Wp)
// dhdc += dhdzo* *dzodc;
// NDArray factor = *dLdh * dhdc;
// NDArray iFactor = factor*dcdzi; // [bS, nOut](or[nOut])
// NDArray fFactor = factor*dcdzf; // [bS, nOut](or[nOut])
// NDArray eFactor = factor*dcdzg; // [bS, nOut](or[nOut])
// NDArray oFactor = *dLdh *dhdzo; // [bS, nOut](or[nOut])
// NDArray tempC = dcdcI;
// if(Wp)
// tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI);
// // dLdx
// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nIn])
// // NDArray temp = c.ulike();
// // applyActivation(c, params[9], params[10], params[11], temp); // actH(c)
// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut])
// // dLdhI
// NDArray* dLdhII = dLdhI;
// if(dLdcI && !dLdhI)
// dLdhII = new NDArray(dLdcI->ulike());
// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut])
// if(firstIter) {
// // dLdcI
// if(dLdcI)
// dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut])
// // dcIdWxi(dcdWxi)
// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]);
// // dcIdWxf(dcdWxf)
// dcIdWxf.assign(dcdzf*dWxJacobian);
// // dcIdWxg(dcdWxg)
// dcIdWxg.assign(dcdzg*dWxJacobian);
// // dcIdWxo(dcdWxo) = 0
// dcIdWxo.nullify();
// // dhIdWxi
// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]);
// // dhIdWxf
// dhIdWxf.assign(dhdc*dcIdWxf);
// // dhIdWxg
// dhIdWxg.assign(dhdc*dcIdWxg);
// // dhIdWxo
// dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */);
// // dcIdWri(dcdWri)
// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nOut, nOut, nOut]);
// // dcIdWrf(dcdWrf)
// dcIdWrf.assign(dcdzf*dWrJacobian);
// // dcIdWrg(dcdWrg)
// dcIdWrg.assign(dcdzg*dWrJacobian);
// // dcIdWro(dcdWro) = 0
// dcIdWro.nullify();
// // dhIdWri
// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut] (or [nOut] * [nOut, nOut, nOut]);
// // dhIdWrf
// dhIdWrf.assign(dhdc*dcIdWrf);
// // dhIdWrg
// dhIdWrg.assign(dhdc*dcIdWrg);
// // dhIdWro
// dhIdWro.assign(dhdzo*dWrJacobian /*+ 0 */);
// if(Wp && x->rankOf() == 1) {
// // dcIdWpi
// dcIdWpi->assign(dcdzi*(*cI)); // [nOut] * [nOut]
// // dcIdWpf
// dcIdWpf->assign(dcdzf*(*cI)); // [nOut] * [nOut]
// // dcIdWpo
// dcIdWpo->nullify(); // [nOut]
// // dhdWpi
// dhIdWpi->assign(dhdc*(*dcIdWpi)); // [nOut] * [nOut]
// // dhdWpf
// dhIdWpf->assign(dhdc*(*dcIdWpf)); // [nOut] * [nOut]
// // dhdWpo
// dhIdWpo->assign(dhdzo*c /* +0*/); // [nOut] * [nOut]
// }
// else if(Wp) {
// // dcIdWpi
// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdWpf
// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdWpo
// dcIdWpo->nullify(); // [nOut]
// // dhIdWpi
// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpf
// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpo
// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// }
// if(b && x->rankOf() == 1) {
// // dcIdbi
// dcIdbi->assign(dcdzi); // [nOut]
// // dcIdbf
// dcIdbf->assign(dcdzf); // [nOut]
// // dcIdbg
// dcIdbg->assign(dcdzg); // [nOut]
// // dcIdbo
// dcIdbo->nullify(); // [nOut]
// //dhIdbi
// dhIdbi->assign(dhdc*(*dcIdbi)); // [nOut]
// //dhIdbf
// dhIdbf->assign(dhdc*(*dcIdbf)); // [nOut]
// //dhIdbg
// dhIdbg->assign(dhdc*(*dcIdbg)); // [nOut]
// //dhIdbo
// dhIdbo->assign(dhdzo); // [nOut]
// }
// else if(b) {
// // dcIdbi
// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdbf
// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdbg
// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdbo
// dcIdbo->nullify(); // [nOut]
// //dhIdbi
// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// //dhIdbf
// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// //dhIdbg
// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// //dhIdbo
// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// }
// }
// else {
// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT);
// NDArray tempO = mmul(dhdzo, WroT);
// // dLdcI
// if(dLdcI)
// dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI));
// // dcIdWxi(dcdWxi)
// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);
// // dcIdWxf(dcdWxf)
// dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf);
// // dcIdWxg(dcdWxg)
// dcIdWxg.assign(dcdzg*dWxJacobian + tempIFE*dhIdWxg + tempC*dcIdWxg);
// // dcIdWxo(dcdWxo)
// dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo);
// // dhIdWxi
// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);
// // dhIdWxf
// dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf);
// // dhIdWxg
// dhIdWxg.assign(dhdc*dcIdWxg + tempO*dhIdWxg);
// // dhIdWxo
// dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo);
// // dcIdWri(dcdWri)
// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nOut, nOut, nOut]);
// // dcIdWrf(dcdWrf)
// dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf);
// // dcIdWrg(dcdWrg)
// dcIdWrg.assign(dcdzg*dWrJacobian + tempIFE*dhIdWrg + tempC*dcIdWrg);
// // dcIdWro(dcdWro)
// dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro);
// // dhIdWri
// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nOut, nOut, nOut]);
// // dhIdWrf
// dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf);
// // dhIdWrg
// dhIdWrg.assign(dhdc*dcIdWrg + tempO*dhIdWrg);
// // dhIdWro
// dhIdWro.assign(dhdzo*dWrJacobian + dhdc*dcIdWro + tempO*dhIdWro);
// if(Wp && x->rankOf() == 1) {
// // dcIdWpi
// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut]
// // dcIdWpf
// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut]
// // dcIdWpo
// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut]
// // dhdWpi
// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut]
// // dhdWpf
// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut]
// // dhdWpo
// dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut]
// }
// else if(Wp) {
// // dcIdWpi
// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dcIdWpf
// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dcIdWpo
// (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpi
// (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpf
// (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpo
// (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// }
// if(b && x->rankOf() == 1) {
// // dcIdbi
// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut]
// // dcIdbf
// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut]
// // dcIdbg
// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut]
// // dcIdbo
// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut]
// //dhIdbi
// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut]
// //dhIdbf
// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut]
// //dhIdbg
// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut]
// //dhIdbo
// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut]
// }
// else if(b) {
// // dcIdbi
// (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdbf
// (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdbg
// (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut]
// // dcIdbo
// (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut]
// //dhIdbi
// (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// //dhIdbf
// (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// //dhIdbg
// (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// //dhIdbo
// (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// }
// }
// const std::vector<int> dimsToExclude = x->rankOf() == 1 ? std::vector<int>({2}) : std::vector<int>({2, 3});
// // dLdWxi, dLdWxf, dLdWxg, dLdWxo
// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude);
// // dLdWri, dLdWrf, dLdWrg, dLdWro
// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude);
// // dLdWpi, dLdWpf, dLdWpo
// if(Wp) {
// if(x->rankOf() == 1) {
// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] * [nOut]
// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] * [nOut]
// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut]
// }
// else {
// // NDArray temp1 = (*dLdWp)({0, nOut});
// // NDArray temp2 = (*dLdWp)({nOut, 2*nOut});
// // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut});
// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// (*dLdWp)({0, nOut}).assign(dhIdWpi);
// (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf);
// (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo);
// }
// }
// // dLdbi, dLdbf, dLdbg, dLdbo
// if(b) {
// if(x->rankOf() == 1) {
// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * [nOut]
// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * [nOut]
// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut]
// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut]
// }
// else {
// // NDArray temp1 = (*dLdb)({0, nOut});
// // NDArray temp2 = (*dLdb)({nOut, 2*nOut});
// // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut});
// // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut});
// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
// (*dLdb)({0, nOut}).assign(dhIdbi);
// (*dLdb)({nOut, 2*nOut}).assign(dhIdbf);
// (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg);
// (*dLdb)({3*nOut, 4*nOut}).assign(dhIdbo);
// }
// }
// //dhIdcI
// if(dLdcI)
// dhIdcI->assign(dhdc);
// cI->assign(c);
// if(dLdcI && !dLdhI)
// delete dLdhII;
// if(Wp) {
// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo;
// }
// if(b) {
// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo;
// }
// }