Shyrma svd (#191)

* - add one additional test for svd

* - provide a float argument in the eye op to specify the data type of the output array

Signed-off-by: Yurii <yurii@skymind.io>

* - add cuda capability check to mmulHelper

Signed-off-by: Yurii <yurii@skymind.io>

* - use another method for device id evaluation

Signed-off-by: Yurii <yurii@skymind.io>

* Eye data type as T argument

Signed-off-by: raver119 <raver119@gmail.com>
master
Yurii Shyrma 2019-08-28 18:27:08 +03:00 committed by raver119
parent dec296da17
commit 70af8c2afc
7 changed files with 52 additions and 31 deletions

View File

@ -218,6 +218,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC);
const int deviceId = AffinityManager::currentDeviceId();
const int major = Environment::getInstance()->capabilities()[deviceId].first();
NDArray::prepareSpecialUse({pC}, {pA, pB});
// choose appropriate cuda gemm api depending on data types
@ -228,20 +231,18 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
float alphaF(alpha), betaF(beta);
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
}
#if __CUDA_ARCH__ >= 530
else if(ABC && aType == DataType::HALF) {
else if(ABC && aType == DataType::HALF && major >= 6) {
float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
}
else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) {
else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6) {
float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
}
else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) {
else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6) {
float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
}
#endif
else {
dim3 threadsPerBlock(N, M);
dim3 blocksPerGrid(1, 1);

View File

@ -27,7 +27,7 @@ namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(eye, -2, 1, false, 0, -2) {
CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) {
helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0));
@ -44,8 +44,7 @@ namespace ops {
std::vector<int> params;
// FIX ME: original has a dtype param - so should be used here instead. e.g. (DataType) INT_ARG(0);
nd4j::DataType dtype = nd4j::DataType::FLOAT32;
nd4j::DataType dtype = block.getTArguments()->empty() ? nd4j::DataType::FLOAT32 : nd4j::DataTypeUtils::fromInt(T_ARG(0));
if(block.width() == 0) {
params = *block.getIArguments();
@ -54,27 +53,27 @@ namespace ops {
for (int i = 0; i < block.width(); i++) {
auto input = INPUT_VARIABLE(i);
REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D");
for (int e = 0; e < input->lengthOf(); e++) {
for (int e = 0; e < input->lengthOf(); e++)
params.emplace_back(input->e<int>(e));
}
}
}
REQUIRE_TRUE(params.size() > 0, 0, "Size not provided for eye op.");
REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op.");
const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f'
if (!ordered)
params.insert(params.begin(), -99);
REQUIRE_TRUE(params.size() > 1, 0, "Size not provided for eye op.");
REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op.");
Nd4jLong* outShapeInfo(nullptr);
const int size = params.size();
switch(size) {
case 2:
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong);
outShapeInfo[0] = 2;
@ -99,7 +98,7 @@ namespace ops {
outShapeInfo[i] = params[i+2];
break;
}
shape::updateStrides(outShapeInfo, static_cast<char>(-params[0]));
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, dtype));
RELEASE(outShapeInfo, block.getWorkspace());

View File

@ -111,18 +111,21 @@ namespace nd4j {
/**
* creates identity 2D matrix or batch of identical 2D identity matrices
*
*
* Input array:
* provide some array - in any case operation simply neglects it
*
*
* Input float argument (if passed):
* TArgs[0] - type of elements of output array, default value is 5 (float)
*
* Input integer arguments:
* IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order
* IArgs[1] - the number of rows in output inner-most 2D identity matrix
* IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
*/
#if NOT_EXCLUDED(OP_eye)
DECLARE_CUSTOM_OP(eye, -2, 1, false, 0, 2);
DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2);
#endif
#if NOT_EXCLUDED(OP_gather_nd)
@ -143,10 +146,10 @@ namespace nd4j {
/**
* clip a list of given tensors with given average norm when needed
*
*
* Input:
* a list of tensors (at least one)
*
*
* Input floating point argument:
* clip_norm - a value that used as threshold value and norm to be used
*
@ -182,12 +185,12 @@ namespace nd4j {
/**
* returns histogram (as 1D array) with fixed bins width
*
*
* Input arrays:
* - input array with elements to be binned into output histogram
* - input array with elements to be binned into output histogram
* - range array with first element being bottom limit and second element being top limit of histogram,
please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1]
*
*
* Input integer arguments:
* nbins (optional) - number of histogram bins, default value is 100
*/

View File

@ -272,6 +272,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_12) {
NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, nd4j::DataType::FLOAT32);
nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
// c.printBuffer();
ASSERT_TRUE(c.equalsTo(&exp));
}

View File

@ -2757,10 +2757,18 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, svd_test11) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.});
auto expS = NDArrayFactory::create<double>('c', {3});
auto expU = NDArrayFactory::create<double>('c', {3,3});
auto expV = NDArrayFactory::create<double>('c', {3,3});
NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555,
0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234,
0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695});
NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571});
NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01,
7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01,
7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01,
8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01});
NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763,
-0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413,
-0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232,
-0.43596, 0.83108, -0.34531});
nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16});
@ -2775,6 +2783,10 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
delete results;
}

View File

@ -639,10 +639,10 @@ TEST_F(DeclarableOpsTests5, eye_test2) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test3) {
auto expected = NDArrayFactory::create<float>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
auto expected = NDArrayFactory::create<int>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4, 2});
auto results = op.execute({}, {9 /*int*/}, {-99, 3, 4, 2});
auto output = results->at(0);
// output->printIndexedBuffer("Output eye");
@ -656,10 +656,10 @@ TEST_F(DeclarableOpsTests5, eye_test3) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test4) {
auto expected = NDArrayFactory::create<float>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
auto expected = NDArrayFactory::create<double>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4, 2, 2});
auto results = op.execute({}, {6/*double*/}, {-99, 3, 4, 2, 2});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());

View File

@ -98,6 +98,9 @@ public class Eye extends DynamicCustomOp {
}
protected void addArgs() {
iArguments.clear();
tArguments.clear();
addIArgument(numRows);
addIArgument(numCols);
if(batchDimension != null) {
@ -105,6 +108,8 @@ public class Eye extends DynamicCustomOp {
addIArgument(dim);
}
}
addTArgument((double) dataType.toInt());
}
@Override