/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // created by Yurii Shyrma on 15.02.2018 // #include #if NOT_EXCLUDED(OP_gru) #include #include namespace sd { namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) { auto x = INPUT_VARIABLE(0); // input [time x bS x iS] auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x nU] auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU] auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU] auto b = INPUT_VARIABLE(4); // biases, [3*nU] auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x nU], that is per each time step const int rank = x->rankOf(); // = 3 const int time = x->sizeAt(0); const int bS = x->sizeAt(1); const int iS = x->sizeAt(2); const int nU = h0->sizeAt(1); const std::vector h0CorrectShape = {bS, nU}; const std::vector wxCorrectShape = {iS, 3*nU}; const std::vector whCorrectShape = {nU, 3*nU}; const std::vector bCorrectShape = {3*nU}; REQUIRE_TRUE(h0->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h); return Status::OK(); } DECLARE_TYPES(gru) { getOpDescriptor() ->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(gru) { const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0 const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits] const auto WhShapeInfo = inputShape->at(3); // hidden-to-hidden weights, [numUnits x 3*numUnits] const auto bShapeInfo = inputShape->at(4); // biases, [3*numUnits] const int rank = shape::rank(xShapeInfo); // = 3 const auto time = xShapeInfo[1]; const auto bS = xShapeInfo[2]; const auto inSize = xShapeInfo[3]; const auto numUnits = h0ShapeInfo[2]; const std::vector h0CorrectShape = {bS, numUnits}; const std::vector wxCorrectShape = {inSize, 3*numUnits}; const std::vector whCorrectShape = {numUnits, 3*numUnits}; const std::vector bCorrectShape = {3*numUnits}; REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); // evaluate output shapeInfo Nd4jLong *hShapeInfo(nullptr); ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); hShapeInfo[0] = rank; hShapeInfo[1] = time; hShapeInfo[2] = bS; hShapeInfo[3] = numUnits; ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo)); return SHAPELIST(hShapeInfo); } } } #endif