2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> at 2/26/2018
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/lup.h>
# if NOT_EXCLUDED(OP_matrix_determinant)
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
CUSTOM_OP_IMPL ( matrix_determinant , 1 , 1 , false , 0 , 0 ) {
auto input = INPUT_VARIABLE ( 0 ) ;
auto output = OUTPUT_VARIABLE ( 0 ) ;
REQUIRE_TRUE ( input - > rankOf ( ) > = 2 , 0 , " matrix_determinant: The rank of input array should not less than 2, but %i is given " , input - > rankOf ( ) ) ;
REQUIRE_TRUE ( input - > sizeAt ( - 1 ) = = input - > sizeAt ( - 2 ) , 0 , " matrix_determinant: The last two dimmensions should be equal, but %i and %i are given " , input - > sizeAt ( - 1 ) , input - > sizeAt ( - 2 ) ) ;
return helpers : : determinant ( block . launchContext ( ) , input , output ) ;
}
DECLARE_SHAPE_FN ( matrix_determinant ) {
auto inShape = inputShape - > at ( 0 ) ;
2020-05-09 07:06:14 +02:00
Nd4jLong const * determinantShape ;
2019-06-06 14:21:15 +02:00
int targetRank = shape : : rank ( inShape ) - 2 ; // last two dimensions will be reduced to scalar
if ( targetRank = = 0 ) { // scalar only
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . scalarShapeInfo ( ArrayOptions : : dataType ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
2019-07-12 10:51:51 +02:00
else if ( targetRank = = 1 ) { // vector
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . vectorShapeInfo ( shape : : sizeAt ( inShape , 0 ) , ArrayOptions : : dataType ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
2019-07-12 10:51:51 +02:00
else { // only two last dimensions are excluded
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . createShapeInfo ( ArrayOptions : : dataType ( inShape ) , shape : : order ( inShape ) , targetRank , shape : : shapeOf ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
return SHAPELIST ( determinantShape ) ;
}
DECLARE_TYPES ( matrix_determinant ) {
getOpDescriptor ( )
2020-03-02 10:49:41 +01:00
- > setAllowedInputTypes ( sd : : DataType : : ANY )
2019-06-06 14:21:15 +02:00
- > setAllowedOutputTypes ( { ALL_FLOATS } ) ;
}
}
}
# endif
# if NOT_EXCLUDED(OP_log_matrix_determinant)
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
DECLARE_TYPES ( log_matrix_determinant ) {
getOpDescriptor ( )
2020-03-02 10:49:41 +01:00
- > setAllowedInputTypes ( sd : : DataType : : ANY )
2019-06-06 14:21:15 +02:00
- > setAllowedOutputTypes ( { ALL_FLOATS } ) ;
}
CUSTOM_OP_IMPL ( log_matrix_determinant , 1 , 1 , false , 0 , 0 ) {
auto input = INPUT_VARIABLE ( 0 ) ;
auto output = OUTPUT_VARIABLE ( 0 ) ;
REQUIRE_TRUE ( input - > rankOf ( ) > = 2 , 0 , " log_matrix_determinant: The rank of input array should not less than 2, but %i is given " , input - > rankOf ( ) ) ;
2021-02-01 06:31:20 +01:00
REQUIRE_TRUE ( input - > sizeAt ( - 1 ) = = input - > sizeAt ( - 2 ) , 0 , " log_matrix_determinant: The last two dimensions should be equal, but %i and %i are given " , input - > sizeAt ( - 1 ) , input - > sizeAt ( - 2 ) ) ;
2019-06-06 14:21:15 +02:00
2019-07-12 10:51:51 +02:00
return helpers : : logAbsDeterminant ( block . launchContext ( ) , input , output ) ;
2019-06-06 14:21:15 +02:00
}
DECLARE_SHAPE_FN ( log_matrix_determinant ) {
auto inShape = inputShape - > at ( 0 ) ;
2020-05-09 07:06:14 +02:00
Nd4jLong const * determinantShape ;
2019-06-06 14:21:15 +02:00
int targetRank = shape : : rank ( inShape ) - 2 ; // last two dimensions will be reduced to scalar
if ( targetRank = = 0 ) { // scalar only
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . scalarShapeInfo ( ArrayOptions : : dataType ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
2019-07-12 10:51:51 +02:00
else if ( targetRank = = 1 ) { // vector
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . vectorShapeInfo ( shape : : sizeAt ( inShape , 0 ) , ArrayOptions : : dataType ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
else { // only two last dimensions are excluded
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . createShapeInfo ( ArrayOptions : : dataType ( inShape ) , shape : : order ( inShape ) , targetRank , shape : : shapeOf ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
return SHAPELIST ( determinantShape ) ;
}
}
}
# endif
# if NOT_EXCLUDED(OP_logdet)
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
DECLARE_TYPES ( logdet ) {
getOpDescriptor ( )
2020-03-02 10:49:41 +01:00
- > setAllowedInputTypes ( sd : : DataType : : ANY )
2019-06-06 14:21:15 +02:00
- > setAllowedOutputTypes ( { ALL_FLOATS } ) ;
}
CUSTOM_OP_IMPL ( logdet , 1 , 1 , false , 0 , 0 ) {
auto input = INPUT_VARIABLE ( 0 ) ;
2020-03-20 06:49:28 +01:00
auto output = OUTPUT_NULLIFIED ( 0 ) ;
2019-06-06 14:21:15 +02:00
REQUIRE_TRUE ( input - > rankOf ( ) > = 2 , 0 , " logdet: The rank of input array should not less than 2, but %i is given " , input - > rankOf ( ) ) ;
REQUIRE_TRUE ( input - > sizeAt ( - 1 ) = = input - > sizeAt ( - 2 ) , 0 , " logdet: The last two dimmensions should be equal, but %i and %i are given " , input - > sizeAt ( - 1 ) , input - > sizeAt ( - 2 ) ) ;
REQUIRE_TRUE ( helpers : : checkCholeskyInput ( block . launchContext ( ) , input ) , 0 , " logdet: The input tensor should be positive-defined hermitian. " ) ;
return helpers : : logdetFunctor ( block . launchContext ( ) , input , output ) ;
}
DECLARE_SHAPE_FN ( logdet ) {
auto inShape = inputShape - > at ( 0 ) ;
2020-05-09 07:06:14 +02:00
Nd4jLong const * determinantShape ;
2019-06-06 14:21:15 +02:00
int targetRank = shape : : rank ( inShape ) - 2 ; // last two dimensions will be reduced to scalar
if ( targetRank = = 0 ) { // scalar only
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . scalarShapeInfo ( ArrayOptions : : dataType ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
2019-07-12 10:51:51 +02:00
else if ( targetRank = = 1 ) { // vector
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . vectorShapeInfo ( shape : : sizeAt ( inShape , 0 ) , ArrayOptions : : dataType ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
else { // only two last dimensions are excluded
2020-06-06 14:26:55 +02:00
determinantShape = ConstantShapeHelper : : getInstance ( ) . createShapeInfo ( ArrayOptions : : dataType ( inShape ) , shape : : order ( inShape ) , targetRank , shape : : shapeOf ( inShape ) ) ;
2019-06-06 14:21:15 +02:00
}
return SHAPELIST ( determinantShape ) ;
}
}
}
# endif