Shyrma svd (#191)
* - add one additional test for svd * - provide float argument in eye op to be a type of output array Signed-off-by: Yurii <yurii@skymind.io> * - add cuda capability check to mmulHelper Signed-off-by: Yurii <yurii@skymind.io> * - make use another method for divice id evaluation Signed-off-by: Yurii <yurii@skymind.io> * Eye data type as T argument Signed-off-by: raver119 <raver119@gmail.com>master
parent
dec296da17
commit
70af8c2afc
|
@ -218,6 +218,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
|
|
||||||
const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC);
|
const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC);
|
||||||
|
|
||||||
|
const int deviceId = AffinityManager::currentDeviceId();
|
||||||
|
const int major = Environment::getInstance()->capabilities()[deviceId].first();
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({pC}, {pA, pB});
|
NDArray::prepareSpecialUse({pC}, {pA, pB});
|
||||||
|
|
||||||
// choose appropriate cuda gemm api depending on data types
|
// choose appropriate cuda gemm api depending on data types
|
||||||
|
@ -228,20 +231,18 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
float alphaF(alpha), betaF(beta);
|
float alphaF(alpha), betaF(beta);
|
||||||
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
|
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
|
||||||
}
|
}
|
||||||
#if __CUDA_ARCH__ >= 530
|
else if(ABC && aType == DataType::HALF && major >= 6) {
|
||||||
else if(ABC && aType == DataType::HALF) {
|
|
||||||
float16 alphaH(alpha), betaH(beta);
|
float16 alphaH(alpha), betaH(beta);
|
||||||
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
|
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
|
||||||
}
|
}
|
||||||
else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) {
|
else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6) {
|
||||||
float alphaF(alpha), betaF(beta);
|
float alphaF(alpha), betaF(beta);
|
||||||
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
|
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
|
||||||
}
|
}
|
||||||
else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) {
|
else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6) {
|
||||||
float alphaF(alpha), betaF(beta);
|
float alphaF(alpha), betaF(beta);
|
||||||
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
|
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
else {
|
else {
|
||||||
dim3 threadsPerBlock(N, M);
|
dim3 threadsPerBlock(N, M);
|
||||||
dim3 blocksPerGrid(1, 1);
|
dim3 blocksPerGrid(1, 1);
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(eye, -2, 1, false, 0, -2) {
|
CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) {
|
||||||
|
|
||||||
helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0));
|
helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0));
|
||||||
|
|
||||||
|
@ -44,8 +44,7 @@ namespace ops {
|
||||||
|
|
||||||
std::vector<int> params;
|
std::vector<int> params;
|
||||||
|
|
||||||
// FIX ME: original has a dtype param - so should be used here instead. e.g. (DataType) INT_ARG(0);
|
nd4j::DataType dtype = block.getTArguments()->empty() ? nd4j::DataType::FLOAT32 : nd4j::DataTypeUtils::fromInt(T_ARG(0));
|
||||||
nd4j::DataType dtype = nd4j::DataType::FLOAT32;
|
|
||||||
|
|
||||||
if(block.width() == 0) {
|
if(block.width() == 0) {
|
||||||
params = *block.getIArguments();
|
params = *block.getIArguments();
|
||||||
|
@ -54,20 +53,20 @@ namespace ops {
|
||||||
for (int i = 0; i < block.width(); i++) {
|
for (int i = 0; i < block.width(); i++) {
|
||||||
auto input = INPUT_VARIABLE(i);
|
auto input = INPUT_VARIABLE(i);
|
||||||
REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D");
|
REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D");
|
||||||
for (int e = 0; e < input->lengthOf(); e++) {
|
|
||||||
|
for (int e = 0; e < input->lengthOf(); e++)
|
||||||
params.emplace_back(input->e<int>(e));
|
params.emplace_back(input->e<int>(e));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
REQUIRE_TRUE(params.size() > 0, 0, "Size not provided for eye op.");
|
REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op.");
|
||||||
|
|
||||||
const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f'
|
const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f'
|
||||||
if (!ordered)
|
if (!ordered)
|
||||||
params.insert(params.begin(), -99);
|
params.insert(params.begin(), -99);
|
||||||
|
|
||||||
REQUIRE_TRUE(params.size() > 1, 0, "Size not provided for eye op.");
|
REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op.");
|
||||||
|
|
||||||
Nd4jLong* outShapeInfo(nullptr);
|
Nd4jLong* outShapeInfo(nullptr);
|
||||||
|
|
||||||
|
|
|
@ -115,6 +115,9 @@ namespace nd4j {
|
||||||
* Input array:
|
* Input array:
|
||||||
* provide some array - in any case operation simply neglects it
|
* provide some array - in any case operation simply neglects it
|
||||||
*
|
*
|
||||||
|
* Input float argument (if passed):
|
||||||
|
* TArgs[0] - type of elements of output array, default value is 5 (float)
|
||||||
|
*
|
||||||
* Input integer arguments:
|
* Input integer arguments:
|
||||||
* IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order
|
* IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order
|
||||||
* IArgs[1] - the number of rows in output inner-most 2D identity matrix
|
* IArgs[1] - the number of rows in output inner-most 2D identity matrix
|
||||||
|
@ -122,7 +125,7 @@ namespace nd4j {
|
||||||
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
|
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_eye)
|
#if NOT_EXCLUDED(OP_eye)
|
||||||
DECLARE_CUSTOM_OP(eye, -2, 1, false, 0, 2);
|
DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_gather_nd)
|
#if NOT_EXCLUDED(OP_gather_nd)
|
||||||
|
|
|
@ -272,6 +272,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_12) {
|
||||||
NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, nd4j::DataType::FLOAT32);
|
NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
|
nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
|
||||||
|
// c.printBuffer();
|
||||||
|
|
||||||
ASSERT_TRUE(c.equalsTo(&exp));
|
ASSERT_TRUE(c.equalsTo(&exp));
|
||||||
}
|
}
|
||||||
|
|
|
@ -2757,10 +2757,18 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests3, svd_test11) {
|
TEST_F(DeclarableOpsTests3, svd_test11) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.});
|
NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555,
|
||||||
auto expS = NDArrayFactory::create<double>('c', {3});
|
0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234,
|
||||||
auto expU = NDArrayFactory::create<double>('c', {3,3});
|
0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695});
|
||||||
auto expV = NDArrayFactory::create<double>('c', {3,3});
|
NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571});
|
||||||
|
NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01,
|
||||||
|
7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01,
|
||||||
|
7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01,
|
||||||
|
8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01});
|
||||||
|
NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763,
|
||||||
|
-0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413,
|
||||||
|
-0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232,
|
||||||
|
-0.43596, 0.83108, -0.34531});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {0, 1, 16});
|
auto results = op.execute({&x}, {}, {0, 1, 16});
|
||||||
|
@ -2775,6 +2783,10 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -639,10 +639,10 @@ TEST_F(DeclarableOpsTests5, eye_test2) {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests5, eye_test3) {
|
TEST_F(DeclarableOpsTests5, eye_test3) {
|
||||||
|
|
||||||
auto expected = NDArrayFactory::create<float>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
|
auto expected = NDArrayFactory::create<int>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
|
||||||
|
|
||||||
nd4j::ops::eye op;
|
nd4j::ops::eye op;
|
||||||
auto results = op.execute({}, {}, {-99, 3, 4, 2});
|
auto results = op.execute({}, {9 /*int*/}, {-99, 3, 4, 2});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
// output->printIndexedBuffer("Output eye");
|
// output->printIndexedBuffer("Output eye");
|
||||||
|
|
||||||
|
@ -656,10 +656,10 @@ TEST_F(DeclarableOpsTests5, eye_test3) {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests5, eye_test4) {
|
TEST_F(DeclarableOpsTests5, eye_test4) {
|
||||||
|
|
||||||
auto expected = NDArrayFactory::create<float>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
|
auto expected = NDArrayFactory::create<double>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
|
||||||
|
|
||||||
nd4j::ops::eye op;
|
nd4j::ops::eye op;
|
||||||
auto results = op.execute({}, {}, {-99, 3, 4, 2, 2});
|
auto results = op.execute({}, {6/*double*/}, {-99, 3, 4, 2, 2});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
|
@ -98,6 +98,9 @@ public class Eye extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
|
iArguments.clear();
|
||||||
|
tArguments.clear();
|
||||||
|
|
||||||
addIArgument(numRows);
|
addIArgument(numRows);
|
||||||
addIArgument(numCols);
|
addIArgument(numCols);
|
||||||
if(batchDimension != null) {
|
if(batchDimension != null) {
|
||||||
|
@ -105,6 +108,8 @@ public class Eye extends DynamicCustomOp {
|
||||||
addIArgument(dim);
|
addIArgument(dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addTArgument((double) dataType.toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue