Shyrma svd (#191)

* - add one additional test for svd

* - provide float argument in eye op to be a type of output array

Signed-off-by: Yurii <yurii@skymind.io>

* - add cuda capability check to mmulHelper

Signed-off-by: Yurii <yurii@skymind.io>

* - make use of another method for device id evaluation

Signed-off-by: Yurii <yurii@skymind.io>

* Eye data type as T argument

Signed-off-by: raver119 <raver119@gmail.com>
master
Yurii Shyrma 2019-08-28 18:27:08 +03:00 committed by raver119
parent dec296da17
commit 70af8c2afc
7 changed files with 52 additions and 31 deletions

View File

@@ -218,6 +218,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC);
const int deviceId = AffinityManager::currentDeviceId();
const int major = Environment::getInstance()->capabilities()[deviceId].first();
NDArray::prepareSpecialUse({pC}, {pA, pB}); NDArray::prepareSpecialUse({pC}, {pA, pB});
// choose appropriate cuda gemm api depending on data types // choose appropriate cuda gemm api depending on data types
@@ -228,20 +231,18 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
} }
#if __CUDA_ARCH__ >= 530 else if(ABC && aType == DataType::HALF && major >= 6) {
else if(ABC && aType == DataType::HALF) {
float16 alphaH(alpha), betaH(beta); float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
} }
else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) { else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6) {
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
} }
else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) { else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6) {
float alphaF(alpha), betaF(beta); float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
} }
#endif
else { else {
dim3 threadsPerBlock(N, M); dim3 threadsPerBlock(N, M);
dim3 blocksPerGrid(1, 1); dim3 blocksPerGrid(1, 1);

View File

@@ -27,7 +27,7 @@ namespace nd4j {
namespace ops { namespace ops {
CUSTOM_OP_IMPL(eye, -2, 1, false, 0, -2) { CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) {
helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0)); helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0));
@@ -44,8 +44,7 @@ namespace ops {
std::vector<int> params; std::vector<int> params;
// FIX ME: original has a dtype param - so should be used here instead. e.g. (DataType) INT_ARG(0); nd4j::DataType dtype = block.getTArguments()->empty() ? nd4j::DataType::FLOAT32 : nd4j::DataTypeUtils::fromInt(T_ARG(0));
nd4j::DataType dtype = nd4j::DataType::FLOAT32;
if(block.width() == 0) { if(block.width() == 0) {
params = *block.getIArguments(); params = *block.getIArguments();
@@ -54,20 +53,20 @@ namespace ops {
for (int i = 0; i < block.width(); i++) { for (int i = 0; i < block.width(); i++) {
auto input = INPUT_VARIABLE(i); auto input = INPUT_VARIABLE(i);
REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D"); REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D");
for (int e = 0; e < input->lengthOf(); e++) {
for (int e = 0; e < input->lengthOf(); e++)
params.emplace_back(input->e<int>(e)); params.emplace_back(input->e<int>(e));
}
} }
} }
REQUIRE_TRUE(params.size() > 0, 0, "Size not provided for eye op."); REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op.");
const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f' const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f'
if (!ordered) if (!ordered)
params.insert(params.begin(), -99); params.insert(params.begin(), -99);
REQUIRE_TRUE(params.size() > 1, 0, "Size not provided for eye op."); REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op.");
Nd4jLong* outShapeInfo(nullptr); Nd4jLong* outShapeInfo(nullptr);

View File

@@ -115,6 +115,9 @@ namespace nd4j {
* Input array: * Input array:
* provide some array - in any case operation simply neglects it * provide some array - in any case operation simply neglects it
* *
* Input float argument (if passed):
* TArgs[0] - type of elements of output array, default value is 5 (float)
*
* Input integer arguments: * Input integer arguments:
* IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order
* IArgs[1] - the number of rows in output inner-most 2D identity matrix * IArgs[1] - the number of rows in output inner-most 2D identity matrix
@@ -122,7 +125,7 @@ namespace nd4j {
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
*/ */
#if NOT_EXCLUDED(OP_eye) #if NOT_EXCLUDED(OP_eye)
DECLARE_CUSTOM_OP(eye, -2, 1, false, 0, 2); DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2);
#endif #endif
#if NOT_EXCLUDED(OP_gather_nd) #if NOT_EXCLUDED(OP_gather_nd)

View File

@@ -272,6 +272,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_12) {
NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, nd4j::DataType::FLOAT32); NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, nd4j::DataType::FLOAT32);
nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.); nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
// c.printBuffer();
ASSERT_TRUE(c.equalsTo(&exp)); ASSERT_TRUE(c.equalsTo(&exp));
} }

View File

@@ -2757,10 +2757,18 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, svd_test11) { TEST_F(DeclarableOpsTests3, svd_test11) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.}); NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555,
auto expS = NDArrayFactory::create<double>('c', {3}); 0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234,
auto expU = NDArrayFactory::create<double>('c', {3,3}); 0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695});
auto expV = NDArrayFactory::create<double>('c', {3,3}); NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571});
NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01,
7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01,
7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01,
8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01});
NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763,
-0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413,
-0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232,
-0.43596, 0.83108, -0.34531});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16}); auto results = op.execute({&x}, {}, {0, 1, 16});
@@ -2775,6 +2783,10 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
delete results; delete results;
} }

View File

@@ -639,10 +639,10 @@ TEST_F(DeclarableOpsTests5, eye_test2) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test3) { TEST_F(DeclarableOpsTests5, eye_test3) {
auto expected = NDArrayFactory::create<float>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); auto expected = NDArrayFactory::create<int>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
nd4j::ops::eye op; nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4, 2}); auto results = op.execute({}, {9 /*int*/}, {-99, 3, 4, 2});
auto output = results->at(0); auto output = results->at(0);
// output->printIndexedBuffer("Output eye"); // output->printIndexedBuffer("Output eye");
@@ -656,10 +656,10 @@ TEST_F(DeclarableOpsTests5, eye_test3) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test4) { TEST_F(DeclarableOpsTests5, eye_test4) {
auto expected = NDArrayFactory::create<float>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); auto expected = NDArrayFactory::create<double>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
nd4j::ops::eye op; nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4, 2, 2}); auto results = op.execute({}, {6/*double*/}, {-99, 3, 4, 2, 2});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());

View File

@@ -98,6 +98,9 @@ public class Eye extends DynamicCustomOp {
} }
protected void addArgs() { protected void addArgs() {
iArguments.clear();
tArguments.clear();
addIArgument(numRows); addIArgument(numRows);
addIArgument(numCols); addIArgument(numCols);
if(batchDimension != null) { if(batchDimension != null) {
@@ -105,6 +108,8 @@ public class Eye extends DynamicCustomOp {
addIArgument(dim); addIArgument(dim);
} }
} }
addTArgument((double) dataType.toInt());
} }
@Override @Override