Alex Black 1170827c18 Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)

* Modified strided_slice op to properly work with empty-like shapes.

* Fixed test for reduce_mean with empty-like input.

* [WIP] Last merge (#15)

* correct logsoftmax loss (#2)

* Small SameDiff listener fix (#4)

* Various fixes (#6)

* #7839 Fix for asXMatrix and tests

* #7866 EmbeddingSequenceLayer dtype fix + test

* #7856 SameDiff save/load stream methods

* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration

* EvaluationBinary 3d/4d

* More evaluation 3d/4d tests

* #7847 Evaluation empty checks

* Small test fix

* #7848 Fix median edge case

* Improve DL4J samediff layer tests

* [WIP] FastText wrapper implemented (#8)

* FastText implemented

* Some fixes

* Fix shapes for wordsNearest

* Validation of input vectors

* Fixes

* Fixed test

* Thread tagged

* Some tweaks

* setContextClassLoader for DeallocatorServiceThread

* Numpy format tests (#1)

* Various fixes (#11)

* #7852 SameDiff gather fix

* #7892 SameDiff placeholder to constant conversion

* #7890 validate input rank for MLN/CG init methods

* Fix broken permute shape calculation

* Permute and gather fixes

* Tests

* #7850 LogSumExp fix + test

* Handful of test fixes

* Empty arrays with non-scalar shapes (#10)

* minor rearrangements for lambdas

* empty tensors with non-scalar shapes

* numpy empty tensors with non-scalar shapes

* few more empty tweaks

* Small fixes

* conv3d signature update

* micro fix in batchnorm mkldnn

* Import fixes

* Fix

* MKL-DNN update

* Small fill fix

* fill with empty input + test

* Fixes

* Small error improvement

* Fix

* one special test

* couple of fixes for lstm

* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone

* Fixes

* FP16

* Unsigned

* BFloat16

* Fill op - empty tweaks

* - couple of fixes for empty arrays construction
- stack updated

* strided slice fix

* one transform test

* provide method for reducing shapeInfo in case the input array is empty

* Fixed reduceAlongDimensions to use empty input properly.

* couple of broadcast tests

* couple more broadcast tests + tweak to make them pass

* add check of non-empty to methods producing sub-arrays

* Fixed reshapeC with zeros in shape.

* complete empty check in reduce_... legacy ops

* Concat and cumsum/prod

* Tweak to empty shape inference on import

* add empty check to the rest of reduce legacy ops

* one more test

* correct typo in evalReduceShapeInfoEmpty

* Added tests for reduce_* ops to tests with zero shapes.

* few more tests for empty reductions

* Fixed strided_slice op with empty case and tests.

* one more empty reduction test

* Fixed strided_slice test.

* add empty check to NDArray::reshapei

* infOrMax

* empty min/max with infinity tests

* made unstack work correctly with empty arrays

* few IndexReduce tests + tweaks for empty shapes

* add test for empty concat

* few tests fixed

* Validation fix for reductions on empty shapes

* Reverse fix

* Reduction shape calc fixes

* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs

* Range fix

* - NDArray constructor updated for scalars/empty arrays
- few tests fixed

* More fixes

* Empty creator fixes

* concat fix

* concat fix

* TF import tests: allow 'both all NaN' and 'both all inf' to pass

* Slice, zero fraction, and reshape fixes

* transpose, gather

* Zero fraction

* scalar cast fix

* Empty reduction axis support

* few more tests fixed

* Fixed input checks to conform with TF for concat op and tests.

* few tests fixed

* matmul scalar shape fix

* Fixed check for data type and scalarity with concat to allow non-empty scalars with vector concats.

* broadcast bool fix

* few more tests

* few more tests

* correct evalReduceShapeInfoEmpty

* argmax/argmin + tests

* one more empty edge case + one more test

* argmax/argmin/realdiv_bp tweaks

* empty reshape test + fix

* Helper fixes

* Small fixes

* Gather test fix

* Gather test fix

* Small fixes

* reduce scalar zero values

* scalar mean workaround

* Remove debug code

* along dim mean workaround

* one more test

* - equalsTo() tweak for empty arrays
- one more test

* broadcast tweaks

* [WIP] Fixing outstanding issues for NLP (#9)

* Avoid using uninitialized objects

* Test fixed.

* Redundant method avoided for models like FastText

* KMeans++ implementation

* KMeans++ implementation

* Disable parallel execution

* KMeans++

* Tests

* Dev branch merge (#16)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray methods reshape(), permute(), transpose() by making them return an instance instead of a pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Fix some issues on master (#17)

* Fix DataVec test issue

* Fix issue with dl4j SameDiff output layer

* Dtype fix for lambda layers

* #7912 BertIterator dtype fix (use float32 not global default)

* [WIP] Next set of CUDA stuff (#7)

New CUDA implementations and improvements

* bad file

* Dev branch master merge (#23)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray methods reshape(), permute(), transpose() by making them return an instance instead of a pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* SameDiff ops, TF import and fixes (#24)

* CheckNumerics tests + fixes + misc fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fake quant

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FakeQuantWithMinMaxArgs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* CheckNumerics fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Exception tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for out of scope stack allocated var use

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignores

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignore for known failing test (already logged issue)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Merge upstream to fork (#25)

* Add thousand-separator commas to TotalParams (#7915)

* Add thousand-separator commas to TotalParams

The number of parameters can be quite large, so adding thousand-separator commas to the TotalParams column and the totals at the bottom makes the summary printout easier to read.

* Add thousand-separator commas to MultiLayerNetwork

Corresponding change to MultiLayerNetwork

Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>

* Update contributing and issue/PR templates (#7934)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix link to AdaDelta paper (#7942)

Fix link to AdaDelta paper hosted on matthewzeiler.com

Signed-off-by: Jxtps

* Fixes, and ignores for known/logged failing issues (#7943)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff + DL4J/SameDiff: Multiple fixes (#28)

* #7919 HDF5 attribute buffer length fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7909 Arbiter constructor exception ux improvements

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7925 RNN output layer length checks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Add listener for validating inputs are not incorrectly modified

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Integrate NonInplaceValidationListener into tests

* #7844 DL4J SameDiff fixes for variable minibatch size

* DL4J SameDiff fixes - ensure gradient for input placeholder is available

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweaks to ExternalErrorsFunction - use placeholders, make more robust

* Another fix

* More fixes

* More SameDiff/DL4J fixes

* Scope out scalar array creation in BaseScalarOp

* Remove debug code

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Final dev branch merge (#29)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray methods reshape(), permute(), transpose() by making them return an instance instead of a pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* [WIP] Multiple dataset iterators (#27)

* Splitting a dataset into an arbitrary number of parts

* Fixes

* Multiple split of iterator

* Test

* Test

* Some fixes

* signature change

* one more tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more test for sequential use of DataSetIteratorSplitter

Signed-off-by: raver119 <raver119@gmail.com>

* Fixes

* Fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* couple of assertions tweaked

Signed-off-by: raver119 <raver119@gmail.com>

* MDS splitter test :/

Signed-off-by: raver119 <raver119@gmail.com>

* Minor refactoring

* Multi dataset

* Some fixes

* More tests

* Small number of test fixes/improvements (failures on CI) (#31)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] More CUDA stuff (#26)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* LRN BP CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* less memory

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed bug with crop_and_resize op helper.

* get rid of unnecessary index-calculation function

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed sort with nth_element cuda-based helper.

* Refactored nth_element.

* Refactored nth_element op and tests.

* Modified usage of dim array with sortTad routine.

* Refactored main routine of helper for non_max_image_suppression op.

* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.

* fix vol2col cuda kernel

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* topK concept

Signed-off-by: raver119 <raver119@gmail.com>

* unsorted topK with scanWidth of 1

Signed-off-by: raver119 <raver119@gmail.com>

* correct vol2col tests

* sorted/unsorted topK

Signed-off-by: raver119 <raver119@gmail.com>

* implementation and fixes for col2im/col2vol

* Corrected usage of input/output flags with reverse op.

* dup is const now

Signed-off-by: raver119 <raver119@gmail.com>

* percentile op

Signed-off-by: raver119 <raver119@gmail.com>

* group tests for maxpool2d

Signed-off-by: Yurii <yurii@skymind.io>

* special test for george

Signed-off-by: raver119 <raver119@gmail.com>

* less threads for sortTad

Signed-off-by: raver119 <raver119@gmail.com>

* provide conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* remove author in sort tad kernel code

Signed-off-by: Yurii <yurii@skymind.io>

* provide depthwise_conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - max_pooling_with_argmax
- null check for special use

Signed-off-by: raver119 <raver119@gmail.com>

* dts cuda

Signed-off-by: raver119 <raver119@gmail.com>

* provide sconv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* std cuda

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op to conform to the TF implementation.

* Improved suppression helper.

* provide pooling3d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* more of minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* (bi)dynamic_rnn

Signed-off-by: raver119 <raver119@gmail.com>

* templates init order

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op.

* Added cuda kernel for non_max_suppression.

* CPU sort by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value tests

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminate compiler error with cuda implementation.

* - repaired gradCheck in cuda
- provide conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* missed signature

Signed-off-by: raver119 <raver119@gmail.com>

* provide depthwise_conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of lup helper with cuda kernel. Initial commit.

* further work on backprops for convolutions

Signed-off-by: Yurii <yurii@skymind.io>

* CUDA linear sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA tad sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* start providing backprop for pooling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* Added atomicAdd for bool datatype.

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition scalar CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* important comment

Signed-off-by: raver119 <raver119@gmail.com>

* fix pooling2d/3d backprop helpers

Signed-off-by: Yurii <yurii@skymind.io>

* Added non-linear test with dynamic_partition.

* Improved test for dynamic_partition.

* dynamic_partition TAD concept

Signed-off-by: raver119 <raver119@gmail.com>

* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix

Signed-off-by: raver119 <raver119@gmail.com>

* - rewrite cpu code for upsampling2d/3d
- write cuda code for upsampling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* dynamic_stitch CUDA vector case

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case impl

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for dynamic_stitch 3D-4D cases.

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed type check for dynamic stitch.

* min/max bp

Signed-off-by: raver119 <raver119@gmail.com>

* rewrite code for upsampling2d/3d cpu

Signed-off-by: Yurii <yurii@skymind.io>

* reduce min/max/norm_max bp

Signed-off-by: raver119 <raver119@gmail.com>

* lup implementation. Additional enhancements.

* provide code for upsampling2d/3d backprop

Signed-off-by: Yurii <yurii@skymind.io>

* weightedCrossEntropyWithLogits

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed template math atomicMul for 64bit ints.

* Refactored dynamic_partition_bp op.

* inverseBroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* DynamicPartitionBP test datatype fixed.

* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA

Signed-off-by: raver119 <raver119@gmail.com>


/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017
//
//@author Yurii Shyrma
//
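//
// Forward recurrence computed here (summary sketch, matching the commented-out reference
// implementation at the bottom of this file; all products below are element-wise):
//   [z_t; f_t; r_t] = W * x_t                    // single matmul, W is [3*inSize x inSize]
//   f_t = sigmoid(f_t + b_f)                     // forget gate
//   r_t = sigmoid(r_t + b_r)                     // reset gate
//   c_t = f_t * c_{t-1} + (1 - f_t) * z_t        // cell state
//   h_t = r_t * tanh(c_t) + (1 - r_t) * x_t      // cell output with highway connection
//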
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_sru)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/sru.h>
#include <MmulHelper.h>
#include <helpers/PointersManager.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize]
auto b = INPUT_VARIABLE(2); // B, bias row of length 2*inSize [2*inSize]
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
auto h = OUTPUT_VARIABLE(0); // cell outputs, [bS x inSize x time]
auto c = OUTPUT_VARIABLE(1); // cell states, [bS x inSize x time]
const int rank = x->rankOf(); // = 3
const auto bS = x->sizeAt(0);
const auto inSize = x->sizeAt(1);
const auto time = x->sizeAt(2);
// input shapes validation
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
// xm = x * mask
auto xm = x;
if(mask) {
xm = new NDArray(x->getShapeInfo(), true, block.launchContext());
x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xm, nullptr);
}
// time loop
helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c);
if(mask)
delete xm;
return Status::OK();
}
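// Minimal usage sketch (hypothetical shapes; the exact DeclarableOp::execute overload is an
// assumption and may differ between releases):
//   auto x  = NDArrayFactory::create<float>('c', {bS, inSize, time});
//   auto w  = NDArrayFactory::create<float>('c', {3*inSize, inSize});
//   auto b  = NDArrayFactory::create<float>('c', {2*inSize});
//   auto c0 = NDArrayFactory::create<float>('c', {bS, inSize});
//   nd4j::ops::sru op;
//   auto results = op.execute({&x, &w, &b, &c0}, {}, {});   // outputs: h and c, both [bS x inSize x time]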
DECLARE_TYPES(sru) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(sru) {
auto xShapeInfo = inputShape->at(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
auto wShapeInfo = inputShape->at(1); // W, 2d tensor of weights [3*inSize x inSize]
auto bShapeInfo = inputShape->at(2); // B, bias row of length 2*inSize [2*inSize]
auto c0ShapeInfo = inputShape->at(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
const int rank = xShapeInfo[0]; // = 3
const int bS = xShapeInfo[1];
const int inSize = xShapeInfo[2];
const int time = xShapeInfo[3];
// input shapes validation
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
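// shapeInfo buffer layout for a rank-3 array (length = rank*2 + 4 = 10 Nd4jLong values):
//   [rank, dim0, dim1, dim2, stride0, stride1, stride2, dataType/flags, ews, order]
// only the rank and dimensions are written manually below; strides, type and order are filled in by updateStridesAndType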
Nd4jLong* newShapeInfo1 = nullptr;
ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time]
newShapeInfo1[0] = rank;
newShapeInfo1[1] = bS;
newShapeInfo1[2] = inSize;
newShapeInfo1[3] = time;
ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, shape::order(xShapeInfo));
ShapeDescriptor descriptor(newShapeInfo1);
RELEASE(newShapeInfo1, block.getWorkspace());
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
return SHAPELIST(result, result);
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
auto b = INPUT_VARIABLE(2); // B, bias row of length 2*K [1 × 2*K]
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
auto c = INPUT_VARIABLE(4); // C, [bS x K x N]
auto inGradCt = INPUT_VARIABLE(5); // [bS x K]
auto inGradH = INPUT_VARIABLE(6); // [bS x K x N]
NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
bool applyMask = false;
if (block.width() > 7) {
mask = INPUT_VARIABLE(7);
applyMask = true;
}
auto gradX = OUTPUT_VARIABLE(0); // [bS x K x N]
auto gradW = OUTPUT_VARIABLE(1); // [bS x 3K x K]
auto gradB = OUTPUT_VARIABLE(2); // [1 x 2K]
auto gradInit = OUTPUT_VARIABLE(3); // [bS x K]
const int bS = x->shapeOf()[0];
const int K = x->shapeOf()[1];
const int N = x->shapeOf()[2]; // N - number of time steps
auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext());
auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext());
auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext());
auto gct = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto gradTanh = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto gradCt = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto ftMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto rtMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto temp1 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto temp2 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
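// gradBias/gradU/gradHX are full-length accumulators, the remaining [bS x K] arrays are per-timestep
// temporaries: all are allocated once here, reused inside the reversed time loop below, and released before this op returns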
// x = x * mask
if(applyMask)
x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask
// multiplication matrix wi = matmul(w,x), U = WX
auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N]
auto wiZ = (*wi)({0,0, 0,K, 0,0}, true); // [bS x K x N]
auto wiF = (*wi)({0,0, K,2*K, 0,0}, true); // forget gate [bS x K x N]
auto wiR = (*wi)({0,0, 2*K,3*K, 0,0}, true); // reset gate [bS x K x N]
auto bF = (*b) ({0,0, 0,K }, true); // biases for forget gate [1 x K]
auto bR = (*b) ({0,0, K,2*K}, true); // biases for reset gate [1 x K]
auto gradBF = (*gradBias)({0,0, 0,K, 0,0}, true); // [bS x K x N]
auto gradBR = (*gradBias)({0,0, K,2*K, 0,0}, true); // [bS x K x N]
auto gradUZ = (*gradU) ({0,0, 0,K, 0,0}, true ); // [bS x K x N]
auto gradUF = (*gradU) ({0,0, K,2*K, 0,0}, true ); // [bS x K x N]
auto gradUR = (*gradU) ({0,0, 2*K,3*K, 0,0}, true ); // [bS x K x N]
NDArray* ct_1 = nullptr;
std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0};
for (int t = N-1; t >=0 ; --t) {
// initialization
idx[4] = t;
idx[5] = t + 1;
auto xt = (*x)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto zt = wiZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto ft = wiF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto rt = wiR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto ct = (*c)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto inGradHt = (*inGradH)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradBRt = gradBR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradBFt = gradBF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradHXt = (*gradHX)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradUZt = gradUZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradUFt = gradUF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradURt = gradUR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
if(t != 0) {
idx[4] = t - 1;
idx[5] = t;
ct_1 = new NDArray((*c)(idx)); // previous c_{t-1}
}
else
ct_1 = c0;
///////////////// forward
// ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
ft.addRowVector(&bF, &ft);
rt.addRowVector(&bR, &rt);
ft.applyTransform(transform::Sigmoid, nullptr, nullptr);
rt.applyTransform(transform::Sigmoid, nullptr, nullptr);
// TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
ct.applyTransform(transform::Tanh, gct);
// ftMinus = 1-ft, rtMinus = 1-rt
ft.applyTransform(transform::OneMinus, ftMinus);
rt.applyTransform(transform::OneMinus, rtMinus);
///////////////// backward
// bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt)
rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt;
temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt;
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
// bF, TODO - tanh
// gradTanh = (1.0f - g_ct * g_ct);
gct->applyPairwiseTransform(pairwise::Multiply, gct, gradTanh, nullptr); // gradTanh = g_ct * g_ct
gradTanh->applyTransform(transform::OneMinus, gradTanh); // gradTanh = (1.0f - g_ct * g_ct)
// gradCt = inGradHt * rt * gradTanh
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, gradCt, nullptr); // gradCt = rt * gradTanh
inGradHt.applyPairwiseTransform(pairwise::Multiply, gradCt, gradCt, nullptr); // gradCt = inGradHt * rt * gradTanh
// gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt)
ct_1->applyPairwiseTransform(pairwise::Subtract, &zt, temp2, nullptr); // temp2 = (ct_1 - zt)
temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft)
temp1->applyPairwiseTransform(pairwise::Multiply, &ft, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft)*ft
temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
// x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr);
// U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh
temp1->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh + inGradCt
temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, &gradUZt, nullptr); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
gradUFt.assign(&gradBFt);
gradURt.assign(&gradBRt);
// c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt)
temp1->applyPairwiseTransform(pairwise::Multiply, &ft, inGradCt, nullptr); // inGradCt = (gradCt + inGradCt) * ft;
if(t != 0)
delete ct_1;
}
// gradInit
gradInit->assign(inGradCt);
// gradX
auto weightsT = w->transpose(); // [K x 3K]
MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N]
gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x
if(applyMask)
gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
// gradB
auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K]
gradB->assign(temp3);
// gradW [bS x 3K x K]
x->permutei({0, 2, 1}); // [bS x N x K]
MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K]
delete gct; delete gradU; delete gradHX;
delete temp1; delete temp2; delete temp3; delete gradCt; delete wi;
delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias;
return Status::OK();
}
DECLARE_TYPES(sru_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(sru_bp) {
auto inShape = inputShape->at(0); // [bS x inSize x time]
auto bS = inShape[1];
auto inSize = inShape[2];
auto time = inShape[3];
char order = (char)(inShape[9]);
ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time});
ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize});
ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize});
ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
}
//////////////////////////////////////////////////////////////////////////
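// Note on layouts: unlike sru above (batch-major [bS x inSize x time]), the bidirectional ops below use
// time-major arrays [time x bS x 2*inSize]; the factor of 2 in the feature dimension covers the two
// directions, which is also why the weights are [2*inSize x 6*inSize] and the biases [1 x 4*inSize].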
CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
auto b = INPUT_VARIABLE(2); // B, bias row of length 4*inSize [1 × 4*inSize]
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize]
auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize]
// input shapes validation
const int rank = x->rankOf();
const Nd4jLong bS = x->sizeAt(1);
const Nd4jLong inSize = x->sizeAt(2) / 2;
REQUIRE_TRUE(x->rankOf() == 3, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", 3, x->rankOf());
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
return Status::OK();
}
DECLARE_TYPES(sru_bi) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(sru_bi) {
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
auto wShapeInfo = inputShape->at(1);
auto bShapeInfo = inputShape->at(2);
auto c0ShapeInfo = inputShape->at(3);
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
const int rank = xShapeInfo[0]; // = 3
const Nd4jLong time = xShapeInfo[1];
const Nd4jLong bS = xShapeInfo[2];
const Nd4jLong inSize = xShapeInfo[3] / 2;
// input shapes validation
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", rank-1, bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
char order = shape::order(xShapeInfo);
ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
return SHAPELIST(result, result);
}
DECLARE_TYPES(sru_bi_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
auto b = INPUT_VARIABLE(2); // B, bias row of length 4*inSize [1 × 4*inSize]
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize]
auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize]
auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize]
NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
// input shapes validation
const int rank = x->rankOf();
const Nd4jLong time = x->sizeAt(0);
const Nd4jLong bS = x->sizeAt(1);
const Nd4jLong inSize = x->sizeAt(2) / 2;
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf());
REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf());
REQUIRE_TRUE(inGradHt->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHt->rankOf());
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ct);
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize]
auto gradC0 = OUTPUT_VARIABLE(3); // [bS x 2*inSize]
helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, mask, gradI, gradW, gradB, gradC0);
return Status::OK();
}
DECLARE_SHAPE_FN(sru_bi_bp) {
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
auto wShapeInfo = inputShape->at(1);
auto bShapeInfo = inputShape->at(2);
auto c0ShapeInfo = inputShape->at(3);
auto ctShapeInfo = inputShape->at(4);
auto inGradC0ShapeInfo = inputShape->at(5);
auto inGradHtShapeInfo = inputShape->at(6);
Nd4jLong* maskShapeInfo = block.width() > 7 ? inputShape->at(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
// input shapes validation
const int rank = xShapeInfo[0];
const Nd4jLong time = xShapeInfo[1];
const Nd4jLong bS = xShapeInfo[2];
const Nd4jLong inSize = xShapeInfo[3] / 2;
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", bShapeInfo);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo);
REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo);
REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]);
REQUIRE_TRUE(inGradHtShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHtShapeInfo[0]);
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
const std::string inGradC0Shape = ShapeUtils::shapeAsString(inGradC0ShapeInfo);
const std::string inGradC0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string inGradHtShape = ShapeUtils::shapeAsString(inGradHtShapeInfo);
const std::string inGradHtCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
REQUIRE_TRUE(inGradC0Shape == inGradC0CorrectShape, 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", inGradC0CorrectShape.c_str(), inGradC0Shape.c_str());
REQUIRE_TRUE(inGradHtShape == inGradHtCorrectShape, 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", inGradHtCorrectShape.c_str(), inGradHtShape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
const char order = shape::order(xShapeInfo);
ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize});
ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {1, 4 * inSize});
ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize});
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
}
}
}
#endif
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
*
* Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K]
* 2: row of biases of length 2K [1 × 2K]
* 3: 2d tensor of previous cell state [bS x K]
* 4: optional, 2d tensor of dropout mask [bS x K]
*
* Output arrays:
* 0: 3d tensor of cell output [bS x K x N]
* 1: 3d tensor of cell state [bS x K x N]
*/
// #if NOT_EXCLUDED(OP_sru)
// DECLARE_CUSTOM_OP(sru_old, 5, 2, false, 0, 0);
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
*
* Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K]
* 2: row of biases of length 2K [1 × 2K]
* 3: 2d tensor of previous cell state [bS x K]
* 4: optional, 2d tensor of dropout mask [bS x K]
*
* Output arrays:
* 0: 3d tensor of cell output [bS x K x N]
* 1: 3d tensor of cell state [bS x K x N]
*/
// #if NOT_EXCLUDED(OP_sru_logic)
// DECLARE_CUSTOM_OP(sru_logic, 5, 2, false, 0, 0);
// #endif
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
*
* Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K]
* 2: row of biases of length 2K [1 × 2K]
* 3: 2d tensor of previous cell state [bS x K]
* 4: 3d tensor of cell state [bS x K x N]
* 5: 2d tensor of cell state gradients [bS x K]
* 6: 3d tensor of state output gradients [bS x K x N]
* 7: optional, 2d tensor of dropout mask [bS x K]
*
* Output arrays:
* 0: 3d tensor of input gradients [bS x K x N]
* 1: 3d tensor of weights gradients [bS x 3K x K]
* 2: 2d, row of biases gradients [1 x 2K]
* 3: 2d, tensor of state gradients [bS x K]
*/
// #if NOT_EXCLUDED(OP_sru_logic)
// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0);
// #endif
// return 2d array evaluated though last dimension interval t1-t2
// static NDArray* timestep(const NDArray* arr, const int t1, const int t2) {
// NDArray* result = new NDArray((*arr)({0,0, 0,0, t1,t2}, true));
// result->reshapei(result->ordering(), {arr->shapeOf()[0], arr->shapeOf()[1]} );
// return result;
// }
/////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) {
// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
// auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 × 2*K]
// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
// bool applyMask = false;
// if (block.width() > 4) {
// mask = INPUT_VARIABLE(4);
// applyMask = true;
// }
// auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N]
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N]
// const int bS = input->shapeOf()[0]; // bS - batch size
// const int K = input->shapeOf()[1]; // K - number of features
// const int N = input->shapeOf()[2]; // N - number of time steps
// const auto wi = mmul(*weights, *input); // U [bS x 3K x N]
// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K]
// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K]
// NDArray xt(input->dataType(), block.launchContext());
// NDArray zt(input->dataType(), block.launchContext());
// NDArray ft(input->dataType(), block.launchContext());
// NDArray rt(input->dataType(), block.launchContext());
// NDArray ht(input->dataType(), block.launchContext());
// NDArray ct = *init;
// NDArray gct(state->ordering(), {bS, K}, input->dataType(), block.launchContext());
// NDArray xmt = *input;
// // input = input * mask
// if(applyMask)
// xmt.applyBroadcast(broadcast::Multiply, {0, 1}, mask, &xmt, nullptr);
// for (int t = 0; t < N; ++t) {
// xt = xmt({0,0, 0,0, t,t+1}); xt.reshapei(xt.ordering(), {bS, K}); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
// zt = wi({0,0, 0, K, t,t+1}); zt.reshapei(zt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
// ft = wi({0,0, K, 2*K, t,t+1}); ft.reshapei(ft.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
// rt = wi({0,0, 2*K,3*K, t,t+1}); rt.reshapei(rt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
// ft = sigmoid_(ft + bF);
// rt = sigmoid_(rt + bR);
// ct = ft * (ct - zt) + zt;
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct.applyTransform(transform::Tanh, &gct);
// ht = rt * (gct - xt) + xt;
// // save results
// (*output)({0,0, 0,0, t,t+1}, true).assign(ht);
// (*state)({0,0, 0,0, t,t+1}, true).assign(ct);
// }
// return Status::OK();
// }
// DECLARE_TYPES(sru_logic) {
// getOpDescriptor()
// ->setAllowedInputTypes(nd4j::DataType::ANY)
// ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_logic) {
// auto inShape = inputShape->at(0); // [bS x K x N]
// int rank = inShape[0]; // = 3
// int size = rank*2 + 4;
// int bS = inShape[1];
// int K = inShape[2];
// int N = inShape[3];
// char order = (char)(inShape[size-1]);
// Nd4jLong* newShapeInfo1 = nullptr;
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
// newShapeInfo1[0] = rank;
// newShapeInfo1[1] = bS;
// newShapeInfo1[2] = K;
// newShapeInfo1[3] = N;
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
// auto result = CONSTANT(newShapeInfo1);
// return SHAPELIST(result, result);
// }
// //////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_old, 5, 2, false, 0, 0) {
// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x inSize]
// auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 × 2*inSize]
// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
// bool applyMask = false;
// if (block.width() > 4) {
// mask = INPUT_VARIABLE(4);
// applyMask = true;
// }
// auto h = OUTPUT_VARIABLE(0); // h_t, [bS x inSize x time]
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x inSize x time]
// const int bS = x->shapeOf()[0]; // bS - batch size
// const int inSize = x->shapeOf()[1]; // inSize - number of features
// const int time = x->shapeOf()[2]; // time - number of time steps
// // multiplication matrix = matmul(w,x)
// auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x time]
// auto wiZ = (*wi)({0,0, 0,inSize, 0,0}, true); // [bS x inSize x time]
// auto wiF = (*wi)({0,0, inSize,2*inSize, 0,0}, true); // forget gate [bS x inSize x time]
// auto wiR = (*wi)({0,0, 2*inSize,3*inSize, 0,0}, true); // reset gate [bS x inSize x time]
// auto bF = (*b) ({0,0, 0,inSize }, true); // biases for forget gate [1 x inSize]
// auto bR = (*b) ({0,0, inSize,2*inSize}, true); // biases for reset gate [1 x inSize]
// NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), *ct(nullptr), *ht(nullptr);
// auto ct_1 = c0->dup(c0->ordering());
// auto gct = NDArrayFactory::create_(state->ordering(), {bS, inSize}, state->dataType(), state->getContext());
// auto xmt = x->dup(x->ordering());
// // x = x * mask
// if(applyMask)
// xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr); // apply mask
// for (int t = 0; t < time; ++t) {
// xt = timestep(xmt, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// zt = timestep(&wiZ, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// ft = timestep(&wiF, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// rt = timestep(&wiR, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// ct = timestep(state, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// ht = timestep(h, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
// ft->addRowVector(&bF, ft);
// rt->addRowVector(&bR, rt);
// ft->applyTransform(transform::Sigmoid, ft, nullptr);
// rt->applyTransform(transform::Sigmoid, rt, nullptr);
// // ct = ft * c_t-1 + (1 - ft) * zt,
// ft->applyPairwiseTransform(pairwise::Multiply, ct_1, ct, nullptr);
// ft->applyTransform(transform::OneMinus, ft);
// ft->applyPairwiseTransform(pairwise::Multiply, *zt, nullptr);
// ct->applyPairwiseTransform(pairwise::Add, *ft, nullptr);
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct->applyTransform(transform::Tanh, gct);
// // ht = rt * gct + (1 - rt) * xt
// rt->applyPairwiseTransform(pairwise::Multiply, gct, ht, nullptr);
// rt->applyTransform(transform::OneMinus, rt);
// rt->applyPairwiseTransform(pairwise::Multiply, *xt, nullptr);
// ht->applyPairwiseTransform(pairwise::Add, *rt, nullptr);
// delete xt; delete zt; delete ft; delete rt; delete ht; delete ct_1;
// ct_1 = ct;
// }
// delete wi; delete ct_1; delete gct; delete xmt;
// return Status::OK();
// }
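// // Summary of the recurrence implemented by the loop above (all products element-wise; wiZ/wiF/wiR
// // are the three [bS x inSize x time] slices of wi = W*x, and xt comes from xmt, i.e. from x with
// // the dropout mask applied when one is supplied):
// //   ft = sigmoid(wiF[:,:,t] + bF)
// //   rt = sigmoid(wiR[:,:,t] + bR)
// //   ct = ft * c_{t-1} + (1 - ft) * wiZ[:,:,t]
// //   ht = rt * tanh(ct) + (1 - rt) * xt            // tanh hard-coded, see the activation TODO above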
// DECLARE_TYPES(sru_old) {
// getOpDescriptor()
// ->setAllowedInputTypes(nd4j::DataType::ANY)
// ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_old) {
// auto inShape = inputShape->at(0); // [bS x inSize x time]
// int rank = inShape[0]; // = 3
// int size = rank*2 + 4;
// auto bS = inShape[1];
// auto inSize = inShape[2];
// int time = inShape[3];
// char order = (char)(inShape[size-1]);
// Nd4jLong *newShapeInfo1 = nullptr;
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
// newShapeInfo1[0] = rank;
// newShapeInfo1[1] = bS;
// newShapeInfo1[2] = inSize;
// newShapeInfo1[3] = time;
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
// auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShapeInfo1));
// RELEASE(newShapeInfo1, block.getWorkspace());
// return SHAPELIST(result, result);
// }
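// // A minimal invocation sketch for this op family. Assumptions (not guaranteed by the commented-out
// // code above): the live nd4j::ops::sru op keeps the same input layout [x, w, b, c0], and the
// // DeclarableOp::execute overload taking initializer lists of inputs/tArgs/iArgs is available.
// //   const int bS = 2, inSize = 4, time = 3;
// //   auto x  = NDArrayFactory::create<float>('c', {bS, inSize, time});     // input features
// //   auto w  = NDArrayFactory::create<float>('c', {3*inSize, inSize});     // stacked z/f/r weights
// //   auto b  = NDArrayFactory::create<float>('c', {1, 2*inSize});          // forget + reset biases
// //   auto c0 = NDArrayFactory::create<float>('c', {bS, inSize});           // initial cell state
// //   nd4j::ops::sru op;
// //   auto results = op.execute({&x, &w, &b, &c0}, {}, {});                 // hypothetical call
// //   auto h = results->at(0);                                              // [bS x inSize x time]
// //   auto c = results->at(1);                                              // [bS x inSize x time]
// //   delete results;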
// static NDArray sigmoid_(const NDArray& arr) {
// NDArray result(arr.getShapeInfo(), false, arr.getContext());
// (const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid, &result);
// return result;
// }
//////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_bp_logic, 8, 4, true, 0, 0) {
// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize]
// auto b = INPUT_VARIABLE(2); // B, row of biases [1 x 2*inSize] (forget- and reset-gate biases concatenated)
// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
// auto c = INPUT_VARIABLE(4); // C, [bS x inSize x time]
// auto inGradCt = INPUT_VARIABLE(5); // [bS x inSize]
// auto inGradH = INPUT_VARIABLE(6); // [bS x inSize x time]
// auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
// auto gradX = OUTPUT_VARIABLE(0); // [bS x inSize x time]
// auto gradW = OUTPUT_VARIABLE(1); // [bS x 3*inSize x inSize]
// auto gradB = OUTPUT_VARIABLE(2); // [2*inSize]
// auto gradInit = OUTPUT_VARIABLE(3); // [bS x inSize]
// // input shapes validation
// const int rank = 3;
// REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BP operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
// REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
// REQUIRE_TRUE(b->rankOf() <= 2, 0, "SRU_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
// REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
// REQUIRE_TRUE(c->rankOf() == rank, 0, "SRU_BP operation: wrong rank of cell states array, expected is %i, but got %i instead !", rank, c->rankOf());
// REQUIRE_TRUE(inGradCt->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, expected is %i, but got %i instead !", rank-1, inGradCt->rankOf());
// REQUIRE_TRUE(inGradH->rankOf() == rank, 0, "SRU_BP operation: wrong rank of array of cell outputs gradients, expected is %i, but got %i instead !", rank, inGradH->rankOf());
// if(mask)
// REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
// const int bS = x->shapeOf()[0];
// const int inSize = x->shapeOf()[1];
// const int time = x->shapeOf()[2]; // time - number of time steps
// const std::string wShape = ShapeUtils::shapeAsString(w);
// const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
// // const std::string bShape = ShapeUtils::shapeAsString(b);
// // const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
// const std::string c0Shape = ShapeUtils::shapeAsString(c0);
// const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
// const std::string cShape = ShapeUtils::shapeAsString(c);
// const std::string cCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
// const std::string inGradCtShape = ShapeUtils::shapeAsString(inGradCt);
// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize});
// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH);
// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
// REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP operation: wrong shape of cell states array, expected is %s, but got %s instead !", cCorrectShape.c_str(), cShape.c_str());
// REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell state gradient, expected is %s, but got %s instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str());
// REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell outputs gradients, expected is %s, but got %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str());
// if(mask) {
// const std::string maskShape = ShapeUtils::shapeAsString(mask);
// REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
// }
// const auto bF = (*b)({0,0, 0, inSize}); // biases for forget gate [1 x inSize]
// const auto bR = (*b)({0,0, inSize,2*inSize}); // biases for reset gate [1 x inSize]
// NDArray gradBias(x->ordering(), {bS, 2*inSize, time}, x->dataType(), block.launchContext());
// NDArray gradU (x->ordering(), {bS, 3*inSize, time}, x->dataType(), block.launchContext());
// NDArray gradHX (x->ordering(), {bS, inSize, time}, x->dataType(), block.launchContext());
// NDArray gct (c->ordering(), {bS, inSize}, x->dataType(), block.launchContext());
// // x = x * mask
// if(mask)
// x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask
// // multiplication matrix wi = matmul(w,x), U = WX
// const auto wi = mmul(*w, *x); // U = W*x, [bS x 3*inSize x time]
// for (int t = time-1; t >=0 ; --t) {
// // initialization
// auto xt = (*x)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto zt = wi({0,0, 0, inSize, t,t+1}); // [bS x 3*inSize x time] -> [bS x inSize]
// auto ft = wi({0,0, inSize, 2*inSize, t,t+1}); // [bS x 3*inSize x time] -> [bS x inSize]
// auto rt = wi({0,0, 2*inSize,3*inSize, t,t+1}); // [bS x 3*inSize x time] -> [bS x inSize]
// auto ct = (*c)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1}
// ///////////////// forward
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
// ft = sigmoid_(ft + bF);
// rt = sigmoid_(rt + bR);
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct.applyTransform(transform::Tanh, &gct);
// ///////////////// backward
// // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
// // ftMinus = -ft + (T)1.;
// NDArray ftMinus = 1. - ft;
// NDArray rtMinus = 1. - rt;
// NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt;
// // bF, TODO - tanh
// NDArray gradTanh = 1. - gct * gct;
// NDArray gradCt = inGradHt * rt * gradTanh;
// NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft;
// // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
// NDArray gradHXt = inGradHt * rtMinus;
// // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
// NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * ftMinus;
// // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
// *inGradCt = (gradCt + *inGradCt) * ft;
// // save results
// gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt);
// gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt);
// gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt);
// gradU({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBFt);
// gradU({0,0, 2*inSize, 3*inSize, t,t+1}, true).assign(gradBRt);
// gradHX({0,0, 0,0, t,t+1}, true).assign(gradHXt);
// }
// // gradInit
// gradInit->assign(inGradCt);
// // gradX
// w->transposei(); // [inSize x 3*inSize]
// gradX->assign( mmul(*w, gradU) + gradHX);
// if(mask)
// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
// // gradB
// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2*inSize]
// // gradW [bS x 3*inSize x inSize]
// x->permutei({0, 2, 1}); // [bS x time x inSize]
// gradW->assign( mmul(gradU, *x) );
// return Status::OK();
// }
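// // After the loop, the per-time-step pieces collected in gradBias/gradU/gradHX are assembled into
// // the op outputs:
// //   gradX    = W^T * gradU + gradHX      (then multiplied by the mask, if one was supplied)
// //   gradB    = sum over {bS, time} of gradBias                              -> [1 x 2*inSize]
// //   gradW    = gradU * x^T, batched with x permuted to [bS x time x inSize] -> [bS x 3*inSize x inSize]
// //   gradInit = the inGradCt value remaining after the t = 0 step            -> [bS x inSize]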
// DECLARE_TYPES(sru_bp_logic) {
// getOpDescriptor()
// ->setAllowedInputTypes(nd4j::DataType::ANY)
// ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_bp_logic) {
// auto inShape = inputShape->at(0); // [bS x inSize x time]
// auto bS = inShape[1];
// auto inSize = inShape[2];
// auto time = inShape[3];
// char order = shape::order(inShape);
// ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time});
// ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize});
// ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize});
// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
// return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
// }