2021-02-09 05:16:31 +01:00
/*
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
*/
2020-02-28 09:37:26 +01:00
//
// Created by GS <sgazeos@gmail.com> at 01/28/2020
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2020-02-28 09:37:26 +01:00
# if NOT_EXCLUDED(OP_lstsq)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/lstsq.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2020-02-28 09:37:26 +01:00
namespace ops {
2020-03-02 10:49:41 +01:00
2020-02-28 09:37:26 +01:00
/**
 * lstsq: least-squares solve of a * x ~= b, delegated to
 * helpers::leastSquaresSolveFunctor.
 *
 * Inputs:
 *   0 - a: left-hand tensor, rank >= 2, trailing dims (M, N).
 *   1 - b: right-hand tensor, rank >= 2, trailing dims (M, K); b's prelast
 *          dimension must equal a's prelast dimension.
 * Optional boolean arg 0 (default true): forwarded to the helper as fastFlag.
 * Optional float arg 0 (default 0.0): forwarded to the helper as l2_factor
 *   (presumably an L2 regularizer, as in TF MatrixSolveLs - confirm in helper).
 * Output 0: x, whose shape is produced by DECLARE_SHAPE_FN(lstsq).
 *
 * If a, b or the output is empty, nothing is computed and Status::OK() is returned.
 */
CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) {
  auto a = INPUT_VARIABLE(0);
  auto b = INPUT_VARIABLE(1);
  auto z = OUTPUT_NULLIFIED(0);

  bool fastFlag = true;
  double l2_factor = 0.;
  if (block.numB() > 0) {
    fastFlag = B_ARG(0);
  }
  if (block.numT() > 0) {
    l2_factor = T_ARG(0);
  }

  REQUIRE_TRUE(a->rankOf() >= 2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
  REQUIRE_TRUE(b->rankOf() >= 2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
  // Only the row counts (prelast dims) must agree; a is not required to be square.
  // sizeAt() yields a 64-bit value, so it is printed with %lld after an explicit cast
  // (passing it to %i is a varargs type mismatch).
  REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0,
               "lstsq: The prelast dimension of left part should be equal to prelast of right part, but %lld and %lld are given",
               static_cast<long long>(a->sizeAt(-2)), static_cast<long long>(b->sizeAt(-2)));

  // Nothing to compute for empty operands/result.
  if (a->isEmpty() || b->isEmpty() || z->isEmpty())
    return Status::OK();

  return helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z);
}
/**
 * solve_ls: least-squares solve of a * x ~= b; same contract as lstsq and
 * delegates to helpers::leastSquaresSolveFunctor.
 *
 * Inputs:
 *   0 - a: left-hand tensor, rank >= 2, trailing dims (M, N).
 *   1 - b: right-hand tensor, rank >= 2, trailing dims (M, K); b's prelast
 *          dimension must equal a's prelast dimension.
 * Optional boolean arg 0 (default true): forwarded to the helper as fastFlag.
 * Optional float arg 0 (default 0.0): forwarded to the helper as l2_factor
 *   (presumably an L2 regularizer - confirm in helper).
 * Output 0: x, whose shape is produced by DECLARE_SHAPE_FN(solve_ls).
 *
 * If a, b or the output is empty, nothing is computed and Status::OK() is returned.
 */
CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) {
  auto a = INPUT_VARIABLE(0);
  auto b = INPUT_VARIABLE(1);
  auto z = OUTPUT_NULLIFIED(0);

  bool fastFlag = true;
  double l2_factor = 0.;
  if (block.numB() > 0) {
    fastFlag = B_ARG(0);
  }
  if (block.numT() > 0) {
    l2_factor = T_ARG(0);
  }

  // Error messages name this op (they previously said "lstsq:" after a copy-paste).
  REQUIRE_TRUE(a->rankOf() >= 2, 0, "solve_ls: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
  REQUIRE_TRUE(b->rankOf() >= 2, 0, "solve_ls: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
  // Only the row counts (prelast dims) must agree; a is not required to be square.
  // sizeAt() yields a 64-bit value, so it is printed with %lld after an explicit cast.
  REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0,
               "solve_ls: The prelast dimension of left part should be equal to prelast of right part, but %lld and %lld are given",
               static_cast<long long>(a->sizeAt(-2)), static_cast<long long>(b->sizeAt(-2)));

  // Nothing to compute for empty operands/result.
  if (a->isEmpty() || b->isEmpty() || z->isEmpty())
    return Status::OK();

  return helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z);
}
// Declares "MatrixSolveLs" as a synonym (alternate name) of the lstsq op.
// NOTE(review): presumably matches the TensorFlow op name for import - confirm.
DECLARE_SYN ( MatrixSolveLs , lstsq ) ;
/**
 * Output shape for lstsq.
 *
 * The result has input 1's (b's) shape with its prelast dimension replaced by
 * the last dimension of input 0 (a): (..., M, N) and (..., M, K) give (..., N, K).
 * The data type is taken from input 0 and the order from input 1.
 * If either input is empty, or the resulting last dimension is 0, an empty
 * shape of input 0's data type is returned.
 */
DECLARE_SHAPE_FN(lstsq) {
  auto in0 = inputShape->at(0);
  auto in1 = inputShape->at(1);

  // The shape function may be invoked before the op body's own checks; with
  // rank < 2 the shapeOf[rank - 2] access below would be out of bounds.
  REQUIRE_TRUE(shape::rank(in0) >= 2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", shape::rank(in0));
  REQUIRE_TRUE(shape::rank(in1) >= 2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", shape::rank(in1));

  auto shapeOf = ShapeUtils::shapeAsVector(in1);
  const auto rank = shapeOf.size();
  shapeOf[rank - 2] = shape::sizeAt(in0, -1);

  const auto dtype = ArrayOptions::dataType(in0);
  if (shape::isEmpty(in0) || shape::isEmpty(in1) || shapeOf[rank - 1] == 0) {
    auto emptyShape = ConstantShapeHelper::getInstance().emptyShapeInfo(dtype);
    return SHAPELIST(emptyShape);
  }

  auto resShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, shape::order(in1), shapeOf);
  return SHAPELIST(resShape);
}
DECLARE_TYPES(lstsq) {
  // lstsq accepts and produces floating-point tensors only; same-mode is switched off.
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes({ALL_FLOATS});
  descriptor->setAllowedOutputTypes({ALL_FLOATS});
  descriptor->setSameMode(false);
}
/**
 * Output shape for solve_ls (same rule as lstsq).
 *
 * The result has input 1's (b's) shape with its prelast dimension replaced by
 * the last dimension of input 0 (a): (..., M, N) and (..., M, K) give (..., N, K).
 * The data type is taken from input 0 and the order from input 1.
 * If either input is empty, or the resulting last dimension is 0, an empty
 * shape is returned. It now uses input 0's data type, the same as the non-empty
 * result and the lstsq shape function; previously it used input 1's.
 */
DECLARE_SHAPE_FN(solve_ls) {
  auto in0 = inputShape->at(0);
  auto in1 = inputShape->at(1);

  // The shape function may be invoked before the op body's own checks; with
  // rank < 2 the shapeOf[rank - 2] access below would be out of bounds.
  REQUIRE_TRUE(shape::rank(in0) >= 2, 0, "solve_ls: The rank of input left tensor should not be less than 2, but %i is given", shape::rank(in0));
  REQUIRE_TRUE(shape::rank(in1) >= 2, 0, "solve_ls: The rank of input right tensor should not be less than 2, but %i is given", shape::rank(in1));

  auto shapeOf = ShapeUtils::shapeAsVector(in1);
  const auto rank = shapeOf.size();
  shapeOf[rank - 2] = shape::sizeAt(in0, -1);

  const auto dtype = ArrayOptions::dataType(in0);
  if (shape::isEmpty(in0) || shape::isEmpty(in1) || shapeOf[rank - 1] == 0) {
    auto emptyShape = ConstantShapeHelper::getInstance().emptyShapeInfo(dtype);
    return SHAPELIST(emptyShape);
  }

  auto resShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, shape::order(in1), shapeOf);
  return SHAPELIST(resShape);
}
DECLARE_TYPES(solve_ls) {
  // solve_ls accepts and produces floating-point tensors only; same-mode is switched off.
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes({ALL_FLOATS});
  descriptor->setAllowedOutputTypes({ALL_FLOATS});
  descriptor->setSameMode(false);
}
}
}
# endif