Merge pull request #9236 from quickwritereader/qwr_tensormmul

fix:  tensormmul_bp shape mismach failure and wrong rank assumptions
master
Adam Gibson 2021-03-18 10:59:32 +09:00 committed by GitHub
commit 7249b6f14c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 2 deletions

View File

@ -141,10 +141,10 @@ CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
// rank always have to be divided by 2
std::vector<int> axesAdLdC, axesBdLdC;
if (dLdCrank > 1) {
axesAdLdC.resize(dLdCrank / 2);
axesAdLdC.resize(axesA.size());
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
}

View File

@ -1803,6 +1803,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) {
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) {
@ -1925,6 +1926,124 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
ASSERT_TRUE(isGradCorrect);
}
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP18a) {
NDArray A('c', {2, 1, 1, 1}, { -1.0034018f, -1.5826055f } , sd::DataType::FLOAT32);
NDArray B('c', {2, 3, 2, 1}, { -0.23549192f, 1.0996383f, 0.09883034f, -1.0160788f, -1.1878633f, 0.32861f, -0.11048671f, -2.555923f, -0.17190187f, 0.030083546f, 0.62453437f, 1.4041749f } , sd::DataType::FLOAT32);
NDArray dLdC('c', {1, 1, 1, 2, 3, 1}, { -3.0080013f, 3.0177708f, 1.3436884f, 8.311761f, 0.24975249f, -5.697828f } , sd::DataType::FLOAT32);
NDArray dA('c', {2, 1, 1, 1}, { -5.109272f, -35.16991f } , sd::DataType::FLOAT32);
NDArray dB('c', {2, 3, 2, 1}, { 3.0182338f, 4.7604795f, -3.0280366f, -4.7759404f, -1.3482592f, -2.1265285f, -8.340035f, -13.154239f, -0.2506021f, -0.39525965f, 5.7172103f, 9.017413f } , sd::DataType::FLOAT32);
sd::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 1, 0 , 1, 2 } );
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto* dLdA = results.at(0);
auto* dLdB = results.at(1);
dLdA->printIndexedBuffer("dLdA");
dA.printIndexedBuffer("dA");
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
}
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP18b) {
NDArray A('c', {2, 1}, { -1.0034018f, -1.5826055f } , sd::DataType::FLOAT32);
NDArray B('c', {2, 3, 2, 1}, { -0.23549192f, 1.0996383f, 0.09883034f, -1.0160788f, -1.1878633f, 0.32861f, -0.11048671f, -2.555923f, -0.17190187f, 0.030083546f, 0.62453437f, 1.4041749f } , sd::DataType::FLOAT32);
NDArray dLdC('c', {1, 2, 3, 1}, { -3.0080013f, 3.0177708f, 1.3436884f, 8.311761f, 0.24975249f, -5.697828f } , sd::DataType::FLOAT32);
NDArray dA('c', {2, 1}, { -5.109272f, -35.16991f } , sd::DataType::FLOAT32);
NDArray dB('c', {2, 3, 2, 1}, { 3.0182338f, 4.7604795f, -3.0280366f, -4.7759404f, -1.3482592f, -2.1265285f, -8.340035f, -13.154239f, -0.2506021f, -0.39525965f, 5.7172103f, 9.017413f } , sd::DataType::FLOAT32);
sd::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 1, -2 , 1, -2 } );
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto* dLdA = results.at(0);
auto* dLdB = results.at(1);
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
}
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP19) {
NDArray A('c', {1, 4, 4}, { 1.0946114f, 0.21210167f, -0.13907638f, -1.0493592f, 1.5186632f, -0.6232642f, -1.2645237f, -0.5646886f, 0.75553244f, 0.64443725f, 1.2988646f, -1.3643093f, 0.44569862f, 0.42895862f, -0.25177997f, 0.6511963f } , sd::DataType::FLOAT32);
NDArray B('c', {4, 4}, { -0.20631117f, -0.3351916f, -0.49059778f, 0.976343f, 0.82378817f, -0.2437844f, -0.030423656f, 1.0651429f, 1.2425208f, 1.040625f, -2.1183474f, -0.36273366f, 2.1668184f, -0.794434f, 0.6742869f, 1.4766127f } , sd::DataType::FLOAT32);
NDArray dLdC('c', {1, 4, 4}, { 1.4648695f, -0.15249155f, -0.94158435f, 1.5351883f, 0.5865028f, 0.2053428f, -0.58049774f, -0.028567377f, 0.027649105f, 0.953271f, -1.4072582f, -1.162804f, -0.26642397f, -0.72383404f, 1.9305837f, -0.08478144f } , sd::DataType::FLOAT32);
NDArray dA('c', {1, 4, 4}, { 1.7097045f, 0.06706808f, -0.7701306f, -0.7323265f, 2.9077587f, 0.42032722f, -1.405354f, -0.1920572f, 3.0991826f, 2.182485f, 4.429202f, -5.143171f, 4.9272313f, 0.67410874f, -3.3633072f, 1.1743239f } , sd::DataType::FLOAT32);
NDArray dB('c', {4, 4}, { 2.00359f, 0.503619f, -2.983953f, 1.9250602f, 1.974581f, -1.1562591f, -0.37881336f, 3.8675072f, 1.8841178f, 2.2428217f, -5.547243f, -0.25318217f, 0.72402f, -0.6912543f, 0.94283605f, 0.9095385f } , sd::DataType::FLOAT32);
sd::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 1, -2 , 1, -2 } );
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto* dLdA = results.at(0);
auto* dLdB = results.at(1);
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
}
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP20) {
NDArray A('c', {2, 3, 2, 2}, { -1.104618f, 0.054949813f, -1.5429245f, 1.2395059f, 0.21299009f, -1.9405359f, 0.2876496f, -0.90021694f, 0.21733008f, -0.5207867f, -1.4787138f, 1.0963368f, 0.4180722f, -1.1413289f, -0.1855793f, -0.33207336f, -1.030655f, -0.78496057f, -0.20828249f, -0.8358334f, 0.21428165f, -0.46897757f, -0.45928282f, -0.6878806f } , sd::DataType::FLOAT32);
NDArray B('c', {2, 2}, { 0.08809058f, 1.4765077f, -0.42534503f, -0.20465784f } , sd::DataType::FLOAT32);
NDArray dLdC('c', {2, 3, 2, 2}, { 0.55896884f, -1.3152053f, -0.52237713f, -0.17254078f, -0.10358791f, 0.25561175f, 0.21195985f, -2.6809795f, 0.64810824f, 0.6235199f, -0.51219785f, -0.9933195f, 0.115763456f, 0.65526706f, 0.04070542f, -1.6172194f, -0.0021990836f, -1.4791434f, 0.28636992f, -0.98794043f, 0.21422987f, 0.4103843f, 0.25127408f, -0.5516688f } , sd::DataType::FLOAT32);
NDArray dA('c', {2, 3, 2, 2}, { -1.8926709f, -0.3007743f, 0.031412438f, 0.25750235f, 0.3682876f, -3.939815f, -0.008252345f, 0.4585274f, 0.97772413f, -1.5117637f, -0.40327787f, 0.42115146f, 0.9777045f, -2.384251f, -0.18334495f, 0.3136628f, -2.1841605f, -1.4334751f, 0.30365366f, 0.08038373f, 0.6248072f, -0.79240835f, -0.17510997f, 0.0060251653f } , sd::DataType::FLOAT32);
NDArray dB('c', {2, 2}, { -0.96445096f, 12.119483f, -3.7955897f, 4.0316315f } , sd::DataType::FLOAT32);
sd::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 1, -2 , 1, -2 } );
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto* dLdA = results.at(0);
auto* dLdB = results.at(1);
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
}
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP21) {
NDArray A('c', {2, 3, 2, 2}, { -1.104618f, 0.054949813f, -1.5429245f, 1.2395059f, 0.21299009f, -1.9405359f, 0.2876496f, -0.90021694f, 0.21733008f, -0.5207867f, -1.4787138f, 1.0963368f, 0.4180722f, -1.1413289f, -0.1855793f, -0.33207336f, -1.030655f, -0.78496057f, -0.20828249f, -0.8358334f, 0.21428165f, -0.46897757f, -0.45928282f, -0.6878806f } , sd::DataType::FLOAT32);
NDArray B('c', {2, 2, 2}, { 1.5363792f, -0.20325208f, 0.5527184f, 2.2368906f, -0.4360202f, 0.14528029f, 0.40532714f, 0.48687512f } , sd::DataType::FLOAT32);
NDArray dLdC('c', {2, 3, 2, 2, 2}, { -2.5499148f, -3.2268374f, -0.14375341f, -0.91169083f, 0.7695215f, 2.7614703f, 0.4784462f, 0.6114677f, 0.48622277f, 0.60015f, 0.023724213f, 0.17099269f, -3.4789655f, -1.6192687f, 0.4812305f, -0.72021484f, -0.48341087f, -3.3518937f, -0.6941231f, -0.6883752f, -0.19416028f, 2.5582364f, 0.6714486f, 0.4581191f, 0.5397443f, -0.5000946f, -0.25750825f, -0.029616296f, -1.9370571f, -0.5108343f, 0.3630441f, -0.32749087f, -1.6985986f, -0.25642234f, 0.3649639f, -0.25114143f, -1.6679776f, -1.710123f, 0.0034727156f, -0.5209858f, 0.075363815f, -1.0709187f, -0.2795909f, -0.19248249f, -1.1007316f, -1.4433929f, -0.07433298f, -0.40304515f } , sd::DataType::FLOAT32);
NDArray dA('c', {2, 3, 2, 2}, { -3.3315463f, 0.5012243f, -9.129614f, 7.0940714f, 0.6395385f, -5.3303494f, 1.7040823f, -5.700614f, 0.14122126f, -1.0444801f, -8.381509f, 6.1103816f, 1.0388733f, -3.078099f, -0.93912476f, -2.2256231f, -2.7531908f, -2.2922633f, -1.4867802f, -4.999527f, 0.42739722f, -1.4239123f, -2.5609136f, -4.063468f } , sd::DataType::FLOAT32);
NDArray dB('c', {2, 2, 2}, { 15.738445f, 7.3535337f, 11.675377f, 21.023096f, -2.1701825f, 3.263466f, 2.1788492f, 5.34995f } , sd::DataType::FLOAT32);
sd::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 1, -2 , 1, -2 } );
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto* dLdA = results.at(0);
auto* dLdB = results.at(1);
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, gru_1) {