/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/CustomOperations.h>
namespace sd {
namespace ops {
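// relu_layer: a fully connected layer followed by a RELU activation.
// Inputs (as enforced by the checks below):
//   0 - x, 2D input tensor   [rows, inFeatures]
//   1 - w, 2D weights tensor [inFeatures, outFeatures]
//   2 - b, 1D biases vector  [outFeatures]
// The optional T argument 0 is used as the RELU cutoff (defaults to 0.0).
// Output: relu(x * w + b)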
CUSTOM_OP_IMPL(relu_layer, 3, 1, false, 0, 0) {
  auto x = INPUT_VARIABLE(0);
  auto w = INPUT_VARIABLE(1);
  auto b = INPUT_VARIABLE(2);

  REQUIRE_TRUE(x->isMatrix(), 0, "relu_layer: x argument should be a 2D tensor, but got rank %i instead!", x->rankOf());
  REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
  REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
  REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: length of the biases vector should match the number of columns of the weights matrix, but got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
  REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match the number of rows of the weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0));

  auto output = OUTPUT_VARIABLE(0);

  // compute x * w + b with the existing xw_plus_b op, writing directly into the output buffer
  sd::ops::xw_plus_b op;
  auto status = op.execute({x, w, b}, {output});
  REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data.");

  // the optional T argument is used as the RELU cutoff; it defaults to 0.0
  auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;

  output->applyScalar(sd::scalar::RELU, scalar, *output);

  return Status::OK();
}
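// The shape function derives the output shape from the matrix product of the
// input and weights shapes, keeping the data type of the input.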
DECLARE_SHAPE_FN(relu_layer) {
  auto inShape = inputShape->at(0);
  auto weightsShape = inputShape->at(1);

  auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.getWorkspace());

  return SHAPELIST(outputShape);
}
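// Inputs may be of any data type; the output is restricted to float types.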
DECLARE_TYPES(relu_layer) {
  getOpDescriptor()
      ->setAllowedInputTypes(sd::DataType::ANY)
      // ->setAllowedInputTypes(1, {ALL_FLOATS})
      ->setAllowedOutputTypes({ALL_FLOATS});
}
}  // namespace ops
}  // namespace sd