2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 06.12.2017.
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_diag_part)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/diag.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
// diag_part: runtime validation + execution.
// The input must have rank 2, 4 or 6 and every one of its dimensions must be
// equal; otherwise REQUIRE_TRUE reports the formatted message below.
// The actual work is delegated to helpers::diagPartFunctor.
// NOTE(review): the helper is not visible here; it presumably gathers the
// diagonal elements of the input into the output — confirm in helpers/diag.h.
CUSTOM_OP_IMPL(diag_part, 1, 1, false, 0, 0) {
    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);

    const int rank = x->rankOf();

    // Only the even ranks 2, 4 and 6 are supported.
    REQUIRE_TRUE(rank == 2 || rank == 4 || rank == 6, 0, "DIAG_PART op: input array must have rank among following three possible values: 2, 4, 6, but got %i instead !", rank);

    // "All dimensions equal" is expressed as: every axis matches axis 0.
    for (int axis = 1; axis < rank; ++axis) {
        REQUIRE_TRUE(x->sizeAt(axis) == x->sizeAt(0), 0, "DIAG_PART op: wrong shape of input array %s ! All dimensions must be equal !", ShapeUtils::shapeAsString(x).c_str());
    }

    helpers::diagPartFunctor(block.launchContext(), x, z);

    return Status::OK();
}
// Declares `DiagPart` as a synonym of `diag_part` (kept ahead of the type
// declaration, as in the original registration order).
DECLARE_SYN(DiagPart, diag_part);

// Type constraints for diag_part: inputs of any data type are allowed, and
// "same mode" is enabled.
// NOTE(review): same mode presumably forces the output dtype to match the
// input dtype — confirm in OpDescriptor.
DECLARE_TYPES(diag_part) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setSameMode(true);
}
// Shape function for diag_part.
// In the shape-info buffer, element [0] holds the rank and elements
// [1..rank] hold the dimensions. The input must have rank 2, 4 or 6 with all
// dimensions equal. The output has rank/2 dimensions, copied from the leading
// dimensions of the input. Strides and data type are then filled in by
// ShapeUtils::updateStridesAndType using the input's order/type.
DECLARE_SHAPE_FN(diag_part) {
    const auto inShape = inputShape->at(0);
    const int rank = static_cast<int>(inShape[0]);

    // Same rank restriction as the op implementation.
    REQUIRE_TRUE(rank == 2 || rank == 4 || rank == 6, 0, "DIAG_PART op: input array must have rank among following three possible values: 2, 4, 6, but got %i instead !", rank);

    // Dimensions live at indices 1..rank; each must equal the first one.
    for (int dim = 2; dim <= rank; ++dim) {
        REQUIRE_TRUE(inShape[dim] == inShape[1], 0, "DIAG_PART op: wrong shape of input array %s ! All dimensions must be equal !", ShapeUtils::shapeAsString(inShape).c_str());
    }

    const int halfRank = rank / 2;

    Nd4jLong* outShapeInfo = nullptr;
    ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(halfRank), Nd4jLong);

    outShapeInfo[0] = halfRank;
    for (int j = 0; j < halfRank; ++j)
        outShapeInfo[1 + j] = inShape[1 + j];

    ShapeUtils::updateStridesAndType(outShapeInfo, inShape, shape::order(inShape));

    // NOTE(review): createFromExisting presumably interns the descriptor and
    // takes care of the temporary buffer allocated above — confirm in
    // ConstantShapeHelper.
    return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting(outShapeInfo, block.workspace()));
}
}
}
# endif