2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright ( c ) 2015 - 2018 Skymind , Inc .
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 29.08.2018
//
# include <op_boilerplate.h>
# if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/generic/helpers/ScatterHelper.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
/**
 * Sparse softmax cross-entropy loss computed directly from logits.
 *
 * Inputs:
 *   0 - labels: its shape must equal the logits shape with the last dimension removed
 *   1 - logits
 * Output:
 *   0 - written by helpers::scatterForLoss from labels and the negated log-softmax of logits
 *       (taken along the last dimension)
 *
 * Fails validation (REQUIRE_TRUE) when labels_rank != logits_rank - 1 or when the
 * labels shape does not match the logits shape without its trailing dimension.
 */
CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) {
    auto labels = INPUT_VARIABLE(0);
    auto logits = INPUT_VARIABLE(1);
    auto output = OUTPUT_VARIABLE(0);

    const int labelsRank = labels->rankOf();
    const int logitsRank = logits->rankOf();

    // input validation
    REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead ! ", labelsRank, logitsRank);

    std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector();
    std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();

    // The shape labels are required to have: logits shape without its last dimension.
    // A separate vector is used so that the diagnostic below can still report the
    // complete logits shape (logitsRank >= 1 is guaranteed by the rank check above).
    std::vector<Nd4jLong> expectedLabelsShape(logitsShape.begin(), logitsShape.end() - 1);
    const bool equalSoft = expectedLabelsShape == labelsShape;

    REQUIRE_TRUE(equalSoft, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead ! ", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());

    std::vector<int> dimension = {-1};

    // Shift by the per-row maximum so that the largest exponent is exp(0) = 1.
    auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true);
    auto shiftedLogits = *logits - maxAlongDim;

    // log(sum(exp(shifted))); the sum is >= 1 after the shift, so the logarithm is finite.
    auto logSumExp = shiftedLogits.transform(transform::Exp, nullptr).reduceAlongDimension(reduce::Sum, dimension, true).transform(transform::Log);

    // -log_softmax = -((x - max) - log(sum(exp(x - max)))).
    // Algebraically equal to -log(exp(x - max) / sum), but it never takes log(0) when an
    // individual exponent underflows, which would otherwise yield an infinite loss.
    auto negLogSoftMax = -(shiftedLogits - logSumExp);

    helpers::scatterForLoss(block.launchContext(), *labels, negLogSoftMax, *output, false);

    return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
// Data type constraints of the forward op:
//   input 0 (labels) -> ALL_INTS
//   input 1 (logits) -> ALL_FLOATS
//   outputs          -> ALL_FLOATS
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) {
    // each setter is invoked on the value returned by the previous call, exactly as in a chained expression
    auto withLabelsTypes = getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS});
    auto withLogitsTypes = withLabelsTypes->setAllowedInputTypes(1, {ALL_FLOATS});
    withLogitsTypes->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
/**
 * Shape function of the forward op.
 *
 * Validates that labels_rank == logits_rank - 1 and that every labels dimension equals
 * the corresponding leading logits dimension, then builds the output shape info from the
 * labels and logits shape infos via ShapeBuilders::copyShapeInfoAndType
 * (presumably labels shape combined with the logits data type - confirm against the helper).
 */
DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) {
    auto labelsShapeInfo = inputShape->at(0);
    auto logitsShapeInfo = inputShape->at(1);

    // Element 0 of a shape-info buffer holds the rank; narrow it to int once so that it
    // matches the %i conversions in the diagnostic below.
    const int labelsRank = static_cast<int>(labelsShapeInfo[0]);
    const int logitsRank = static_cast<int>(logitsShapeInfo[0]);

    REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead ! ", labelsRank, logitsRank);

    // Dimensions occupy indices 1..rank, so the upper bound must be inclusive;
    // otherwise the last labels dimension (the only one for rank-1 labels) goes unchecked.
    bool equalSoft = true;
    for (int i = 1; i <= labelsRank; ++i) {
        if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
            equalSoft = false;
            break;
        }
    }

    REQUIRE_TRUE(equalSoft, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead ! ", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());

    auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace());

    return SHAPELIST(CONSTANT(outShapeInfo));
}
//////////////////////////////////////////////////////////////////////////
/**
 * Gradient of the sparse softmax cross-entropy loss with respect to logits.
 *
 * Inputs:
 *   0 - labels: its shape must equal the logits shape with the last dimension removed
 *   1 - logits
 * Output:
 *   0 - dL/dlogits: softmax(logits) along the last dimension, afterwards adjusted by
 *       helpers::scatterForLoss at the positions selected by labels
 *
 * Fails validation (REQUIRE_TRUE) when labels_rank != logits_rank - 1 or when the
 * labels shape does not match the logits shape without its trailing dimension.
 */
CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) {
    auto labels = INPUT_VARIABLE(0);
    auto logits = INPUT_VARIABLE(1);

    auto dLdp = OUTPUT_VARIABLE(0);     // dL/dlogits

    const int labelsRank = labels->rankOf();
    const int logitsRank = logits->rankOf();

    // input validation
    REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead ! ", labelsRank, logitsRank);

    std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector();
    std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();

    // The shape labels are required to have: logits shape without its last dimension.
    // A separate vector is used so that the diagnostic below can still report the
    // complete logits shape (logitsRank >= 1 is guaranteed by the rank check above).
    std::vector<Nd4jLong> expectedLabelsShape(logitsShape.begin(), logitsShape.end() - 1);
    const bool equalSoft = expectedLabelsShape == labelsShape;

    REQUIRE_TRUE(equalSoft, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead ! ", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());

    std::vector<int> dimension = {-1};

    // softmax along the last dimension; the per-row maximum is subtracted before exponentiation
    NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp);
    softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true);

    // dEdp = softmax - 1 (or 0)
    dLdp->assign(softmax);

    // subtract unities at appropriate indexes of dLdp array
    helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true);

    return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
// Data type constraints of the gradient op:
//   input 0 (labels) -> ALL_INTS
//   input 1 (logits) -> ALL_FLOATS
//   outputs          -> ALL_FLOATS
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) {
    // each setter is invoked on the value returned by the previous call, exactly as in a chained expression
    auto withLabelsTypes = getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS});
    auto withLogitsTypes = withLabelsTypes->setAllowedInputTypes(1, {ALL_FLOATS});
    withLogitsTypes->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
/**
 * Shape function of the gradient op.
 *
 * Validates that labels_rank == logits_rank - 1 and that every labels dimension equals
 * the corresponding leading logits dimension, then builds the dL/dlogits shape info from
 * the logits shape info, with the data type chosen by DataTypeUtils::pickFloatingType
 * applied to the logits data type.
 */
DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits_grad) {
    auto labelsShapeInfo = inputShape->at(0);
    auto logitsShapeInfo = inputShape->at(1);

    // Element 0 of a shape-info buffer holds the rank; narrow it to int once so that it
    // matches the %i conversions in the diagnostic below.
    const int labelsRank = static_cast<int>(labelsShapeInfo[0]);
    const int logitsRank = static_cast<int>(logitsShapeInfo[0]);

    REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead ! ", labelsRank, logitsRank);

    // Dimensions occupy indices 1..rank, so the upper bound must be inclusive;
    // otherwise the last labels dimension (the only one for rank-1 labels) goes unchecked.
    bool equalSoft = true;
    for (int i = 1; i <= labelsRank; ++i) {
        if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
            equalSoft = false;
            break;
        }
    }

    REQUIRE_TRUE(equalSoft, 0, " SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead ! ", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());

    DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));

    Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace());

    return SHAPELIST(CONSTANT(dLdpShapeInfo));
}
}
}
# endif