2019-06-06 14:21:15 +02:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author raver119@gmail.com
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_batched_gemm)
# include <ops/declarable/headers/blas.h>
# include <ops/declarable/helpers/batched_gemm.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
/**
 * batched_gemm: validates its arguments and forwards a batch of 2D matrix
 * products to sd::ops::helpers::bgemm.
 *
 * Integer arguments (exactly 9):
 *   0    - transA: 0 or 111 for NoTrans, 1 or 112 for Trans
 *   1    - transB: same encoding as transA
 *   2..4 - M, N, K
 *   5..7 - ldA, ldB, ldC
 *   8    - batchSize
 *
 * Inputs (batchSize * 2 + 2 arrays):
 *   0                           - alpha
 *   1                           - beta
 *   2 .. batchSize + 1          - A matrices
 *   batchSize + 2 .. 2 * batchSize + 1 - B matrices
 *
 * Outputs: batchSize C matrices (output e corresponds to batch e).
 *
 * Every validation failure is reported through REQUIRE_TRUE.
 */
CUSTOM_OP_IMPL(batched_gemm, -1, -1, false, 0, 9) {
    int transA = INT_ARG(0);
    int transB = INT_ARG(1);
    const int M = INT_ARG(2);
    const int N = INT_ARG(3);
    const int K = INT_ARG(4);
    const int ldA = INT_ARG(5);
    const int ldB = INT_ARG(6);
    const int ldC = INT_ARG(7);
    const int batchSize = INT_ARG(8);

    // Accept the short 0/1 flags and translate them to the 111/112 codes
    // (NoTrans/Trans) that are validated below and handed to the helper.
    if (transA == 0)
        transA = 111;
    else if (transA == 1)
        transA = 112;

    if (transB == 0)
        transB = 111;
    else if (transB == 1)
        transB = 112;

    // basically A+B and 2 arrays of alpha and beta
    const int expectedWidth = batchSize * 2 + 2;

    REQUIRE_TRUE((transA == 111 || transA == 112) && (transB == 111 || transB == 112), 0, "BatchedGemm: valid values for transA and transB are: 0/1 or 111/112, for NoTrans/Trans respectively");
    // Must run before the vectors below are sized with batchSize.
    REQUIRE_TRUE(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0, 0, "BatchedGemm: M, N, K, ldA, ldB, ldC and batchSize must all be positive");
    REQUIRE_TRUE(block.width() == expectedWidth, 0, "BatchedGemm: expected number of input arrays is %i, but got %i instead", expectedWidth, block.width());

    auto alpha = INPUT_VARIABLE(0);
    auto beta = INPUT_VARIABLE(1);

    // alpha's data type is the reference every other array is compared against.
    const auto firstType = alpha->dataType();
    REQUIRE_TRUE(firstType == beta->dataType(), 0, "BatchedGemm: all inputs and outputs must have same data type");

    std::vector<NDArray*> vA(batchSize);
    std::vector<NDArray*> vB(batchSize);
    std::vector<NDArray*> vC(batchSize);

    // NOTE(review): the shape checks below require A to be MxK and B to be KxN
    // regardless of transA/transB - confirm this matches what the helper expects.
    for (int e = 0; e < batchSize; e++) {
        vA[e] = INPUT_VARIABLE(e + 2);
        vB[e] = INPUT_VARIABLE(e + 2 + batchSize);
        vC[e] = OUTPUT_VARIABLE(e);

        REQUIRE_TRUE(firstType == vA[e]->dataType() && firstType == vB[e]->dataType() && firstType == vC[e]->dataType(), 0, "BatchedGemm: all inputs and outputs must have same data type");

        REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of A should be equal to 2", e);
        REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of B should be equal to 2", e);
        REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of C should be equal to 2", e);

        REQUIRE_TRUE(M == vA[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.rows() should be equal to M", e);
        REQUIRE_TRUE(N == vB[e]->sizeAt(1), 0, "BatchedGemm: batch %i, number of B.columns() should be equal to N", e);
        REQUIRE_TRUE(K == vA[e]->sizeAt(1) && K == vB[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.columns() and B.rows() should be equal to K", e);
    }

    // vA, vB and vC are constructed with exactly batchSize elements, so no
    // separate size-consistency check is needed here.
    sd::ops::helpers::bgemm(vA, vB, vC, alpha, beta, transA, transB, M, N, K, ldA, ldB, ldC);

    return Status::OK();
}
/**
 * Shape function for batched_gemm.
 *
 * Integer arguments read here: 2 = M, 3 = N, 4 = K, 5 = ldA, 6 = ldB,
 * 7 = ldC, 8 = batchSize.
 *
 * Requires every input to have the same data type as input 0.
 *
 * Returns a shape list holding batchSize descriptors of shape {M, N} in 'f'
 * order, typed like input 0. If any of the integer arguments above is not
 * positive, a single 'c'-order {1, 1} descriptor is returned instead.
 */
DECLARE_SHAPE_FN(batched_gemm) {
    const int M = INT_ARG(2);
    const int N = INT_ARG(3);
    const int K = INT_ARG(4);
    const int ldA = INT_ARG(5);
    const int ldB = INT_ARG(6);
    const int ldC = INT_ARG(7);
    const int batchSize = INT_ARG(8);

    const auto firstType = ArrayOptions::dataType(inputShape->at(0));
    // Compare each remaining input (index e, not a fixed index) with input 0.
    for (int e = 1; e < block.width(); e++) {
        REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(e)), 0, "BatchedGemm: all inputs must have same data type");
    }

    auto shapeList = SHAPELIST();

    const bool argsArePositive = M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0;
    if (!argsArePositive) {
        // Invalid arguments: hand back a single {1, 1} placeholder shape.
        shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(firstType, 'c', {1, 1}));
        return shapeList;
    }

    // One {M, N} descriptor in 'f' order per batch entry.
    std::vector<Nd4jLong> shape({M, N});
    for (int e = 0; e < batchSize; e++) {
        auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(firstType, 'f', shape);
        shapeList->push_back(newShape);
    }

    return shapeList;
}
// Type declaration for batched_gemm: the allowed input types and the allowed
// output types are both set to ALL_FLOATS on the op descriptor.
DECLARE_TYPES(batched_gemm) {
    auto descriptor = getOpDescriptor();
    // The output setter is invoked on whatever the input setter returns,
    // exactly as a chained call would do.
    auto withInputTypes = descriptor->setAllowedInputTypes({ALL_FLOATS});
    withInputTypes->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
# endif