diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java index bc892905c..44aa9a0b3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java @@ -23,6 +23,11 @@ import org.junit.Test; public class TestDataSets extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testTinyImageNetExists() throws Exception { //Simple sanity check on extracting diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java index 91ec8c98e..721786eef 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java @@ -44,6 +44,11 @@ import static org.junit.Assert.*; public class TestCheckpointListener extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Rule public TemporaryFolder tempDir = new TemporaryFolder(); @@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator iter = new IrisDataSetIterator(75,150); + DataSetIterator iter = new IrisDataSetIterator(25,50); return new Pair<>(net, iter); } @@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest { CheckpointListener l = new CheckpointListener.Builder(f) .keepLast(3) - .saveEvery(3, TimeUnit.SECONDS) + .saveEvery(4, TimeUnit.SECONDS) .build(); net.setListeners(l); for(int i=0; i<5; i++ ){ //10 iterations total net.fit(iter); - Thread.sleep(4000); + Thread.sleep(5000); } //Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 should happen before the first 4 seconds are up) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index a66914cd7..05bb8b5eb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -54,6 +54,11 @@ import static org.junit.Assert.*; @Slf4j public class RegressionTest100a extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index b1bef2bfc..a28c1b845 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -52,6 +52,11 @@ import static org.junit.Assert.*; public class RegressionTest100b3 extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index a4883ea07..ec8531eb2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -69,6 +69,11 @@ import org.nd4j.resources.Resources; public class RegressionTest100b4 extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Override public DataType getDataType() { return DataType.FLOAT; @@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); - assertEquals("Test for dtype: " + dtypeName, outExp, outAct); + boolean eq = outExp.equalsWithEps(outAct, 0.01); + assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 637f5860f..22ac01c14 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest { return DataType.FLOAT; } + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Test public void testCustomLayer() throws Exception { @@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); - assertTrue(outExp + " vs " + outAct, eq); } + assertTrue(outExp + " vs " + outAct, eq); + } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index b17c215cb..7538d39bc 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -96,7 +96,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } }; - @Test(expected = IllegalStateException.class) + @Override + public long 
getTimeoutMilliseconds() { + return 90000L; + } + + @Test(expected = IllegalStateException.class) public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; importEndModelTest(modelPath, null, true, true, false, false); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index 87d88efcb..cc387446a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn> + <profile> + <id>testresources</id> + <activation> + <activeByDefault>true</activeByDefault> + </activation> + </profile> + <id>test-nd4j-native</id> diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java index c57b0fa30..8e109689a 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java @@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer; import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.impl.IntegerSerde; +import org.nd4j.resources.Resources; import org.nd4j.shade.jackson.databind.ObjectMapper; import javax.imageio.ImageIO; @@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { @Test public void testMlnMnist_ImageInput() throws Exception { - val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); + val modelFile = Resources.asFile("models/mnist/mnist-model.zip"); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); val server = new JsonModelServer.Builder(net) @@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { @Test public void testMlnMnist_ImageInput_Async() throws Exception { - val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); + val modelFile = Resources.asFile("models/mnist/mnist-model.zip"); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); val server = new JsonModelServer.Builder(net) @@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { @Test public void testBinaryIn_BinaryOut() throws Exception { - val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); + val modelFile = Resources.asFile("models/mnist/mnist-model.zip"); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); val server = new JsonModelServer.Builder(net) diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml index 4ef2e06dd..4329a1554 100644 --- a/deeplearning4j/deeplearning4j-remote/pom.xml +++ b/deeplearning4j/deeplearning4j-remote/pom.xml @@ -20,6 +20,9 @@ deeplearning4j-remote + <profile> + <id>testresources</id> + </profile> <id>test-nd4j-native</id> diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 6e38be5c1..517ca9888 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++
b/libnd4j/include/helpers/MmulHelper.h @@ -60,7 +60,7 @@ namespace sd { static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB); #endif - static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY); + static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0); }; } diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index bc525622a..f5b9bc829 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND ////////////////////////////////////////////////////////////////////////// - void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) { + void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) { int xRank = x->rankOf(); int yRank = y->rankOf(); @@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); } - mmul(xT, yT, zT, 1., 0.); + mmul(xT, yT, zT, alpha, beta); } else { // rest cases - batched mmul @@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND auto xSubArr = (*xT)(i, dimsToExclude); auto ySubArr = (*yT)(i, dimsToExclude); auto zSubArr = (*zT)(i, dimsToExclude); - mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.); + mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta); } } diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index 6209e7bbf..370aa50c6 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - const int iSize = (int) block.getIArguments()->size(); + int iSize = (int) block.getIArguments()->size(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optionally use alpha and beta + iSize = (int)block.getTArguments()->size(); + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; const int xRank = x->rankOf(); const int yRank = y->rankOf(); @@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { } // ******* end of input validation ******* // - MmulHelper::matmul(x, y, z, transX, transY); + MmulHelper::matmul(x, y, z, transX, transY, alpha, beta); return Status::OK(); } @@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { auto dldx = OUTPUT_VARIABLE(0); auto dldy = OUTPUT_VARIABLE(1); - const int iSize = (int) block.getIArguments()->size(); + int iSize = (int) block.getIArguments()->size(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optionally use alpha and beta + iSize = (int)block.getTArguments()->size(); + + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ?
T_ARG(1) : 0.0; + /* In: x=[a,b], y=[b,c] tX tY tZ x y z dz dLdx dLdy @@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a] sd::ops::matmul op; - op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {}); - op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {}); + op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {}); + op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {}); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 91e56d801..f3ef84e2f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -32,7 +32,7 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) { +static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) { // mkl works with following // [M,K] x [K,N] = [M,N] @@ -150,6 +150,12 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b // Create attributes (to handle alpha and beta if necessary) dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) + if (alpha != 1.f) attr.set_output_scales(0, {alpha}); + if (beta != 0.f) { + dnnl::post_ops po; + po.append_sum(beta); + attr.set_post_ops(po); + } // operation primitive description dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); @@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) { if(x->isEmpty() || y->isEmpty()) return Status::OK(); - const int iSize = (int) block.getIArguments()->size(); + int iSize = (int) block.getIArguments()->size(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optionally use alpha and beta + iSize = (int)block.getTArguments()->size(); + float alpha = iSize > 0 ? T_ARG(0) : 1.0; + float beta = iSize > 1 ? T_ARG(1) : 0.0; + const int xRank = x->rankOf(); const int yRank = y->rankOf(); const int zRank = z->rankOf(); @@ -265,7 +276,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) { } // ******* end of input validation ******* // - matmulMKLDNN(x, y, z, transX, transY); + matmulMKLDNN(x, y, z, transX, transY, alpha, beta); return Status::OK(); } @@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); - auto z = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); const DataType xType = x->dataType(); const DataType yType = y->dataType(); const DataType zType = z->dataType(); + float alpha = block.numT() > 0 ? T_ARG(0) : 1.0; + float beta = block.numT() > 1 ?
T_ARG(1) : 0.0; - return block.isUseMKLDNN() && x->rankOf() < 3 && + return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 && ( (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 8b2bd0071..f48e3d946 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -39,6 +39,97 @@ public: } }; +TEST_F(DeclarableOpsTests19, test_matmul_ccc) { + auto x = NDArrayFactory::create<float>('c', {10, 10}); + auto y = NDArrayFactory::create<float>('c', {10, 10}); + auto e = NDArrayFactory::create<float>('c', {10, 10}); + auto z = NDArrayFactory::create<float>('c', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_fcf) { + auto x = NDArrayFactory::create<float>('f', {10, 10}); + auto y = NDArrayFactory::create<float>('c', {10, 10}); + auto e = NDArrayFactory::create<float>('f', {10, 10}); + auto z = NDArrayFactory::create<float>('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_cff) { + auto x = NDArrayFactory::create<float>('c', {10, 10}); + auto y = NDArrayFactory::create<float>('f', {10, 10}); + auto e = NDArrayFactory::create<float>('f', {10, 10}); + auto z = NDArrayFactory::create<float>('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + + +TEST_F(DeclarableOpsTests19, test_matmul_ccf) { + auto x = NDArrayFactory::create<float>('c', {10, 10}); + auto y = NDArrayFactory::create<float>('c', {10, 10}); + auto e = NDArrayFactory::create<float>('f', {10, 10}); + auto z = NDArrayFactory::create<float>('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_fff) { + auto x = NDArrayFactory::create<float>('f', {10, 10}); + auto y = NDArrayFactory::create<float>('f', {10, 10}); + auto e = NDArrayFactory::create<float>('f', {10, 10}); + auto z = NDArrayFactory::create<float>('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) { /* DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") @@ -74,4 +165,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) { sd::ops::squeeze op; auto status = op.execute({&x}, {&e}, {axis}); ASSERT_EQ(Status::OK(), status); -} \ No newline at end of file +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java index 824f46c82..8736c2363 100644 ---
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java @@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl; import lombok.val; import org.nd4j.linalg.api.blas.Level2; import org.nd4j.linalg.api.blas.params.GemvParameters; +import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -57,6 +59,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { OpProfiler.getInstance().processBlasCall(false, A, X, Y); GemvParameters parameters = new GemvParameters(A, X, Y); + + Nd4j.exec(new Mmul(A, X, Y, alpha, beta, MMulTranspose.builder().transposeA(false).build())); + + /* if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(), parameters.getY()); @@ -86,7 +92,7 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { } else { throw new ND4JIllegalStateException("Unsupported data type " + A.dataType()); } - + */ OpExecutionerUtil.checkForAny(Y); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java index e38a0e618..8d9765aee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java @@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl; import lombok.extern.slf4j.Slf4j; import org.nd4j.linalg.api.blas.Level3; import org.nd4j.linalg.api.blas.params.GemmParams; +import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.OpProfiler; @@ -59,6 +61,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { GemmParams params = new GemmParams(A, B, C); + Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(false).transposeB(false).build())); + + /* int charOder = Order; if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), params.getC()); @@ -73,6 +78,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), 1.0f, params.getA(), params.getLda(), params.getB(), params.getLdb(), 0, C, params.getLdc()); } + */ OpExecutionerUtil.checkForAny(C); } @@ -85,6 +91,9 
@@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(true, A, B, C); + Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build())); + + /* GemmParams params = new GemmParams(A, B, C, transposeA, transposeB); if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C); @@ -102,7 +111,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc()); } - +*/ OpExecutionerUtil.checkForAny(C); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 3edbd2682..07a2bf9b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2866,16 +2866,22 @@ public abstract class BaseNDArray implements INDArray, Iterable { } @Override - public INDArray mmul(INDArray other) { + public INDArray mmul(INDArray other, char resultOrder) { + Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); - // FIXME: for 1D case, we probably want vector output here? - long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()}; - INDArray result = createUninitialized(this.dataType(), shape, 'f'); + // FIXME: add support for 3D+ here? + long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; + INDArray result = createUninitialized(this.dataType(), shape, resultOrder); if (result.isScalar()) return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); return mmuli(other, result); } + @Override + public INDArray mmul(INDArray other) { + return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c'); + } + protected INDArray create(int[] shape, char ordering) { return Nd4j.create(shape, ordering); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index de80e9413..08aa613fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1232,6 +1232,14 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mmul(INDArray other); + /** + * Perform a copy matrix multiplication + * @param other the other matrix to perform matrix multiply with + * @param resultOrder either C or F order for result array + * @return the result of the matrix multiplication + */ + INDArray mmul(INDArray other, char resultOrder); + /** * Convert this ndarray to a 2d double matrix.
* Note that THIS SHOULD NOT BE USED FOR SPEED. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 30ca8ebc5..46310893d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -44,6 +44,8 @@ import java.util.*; public class Mmul extends DynamicCustomOp { protected MMulTranspose mt; + protected double alpha = 1.0; + protected double beta = 0.0; /** * @@ -59,6 +61,7 @@ public class Mmul extends DynamicCustomOp { super(null,sameDiff,new SDVariable[]{i_v1,i_v2}); this.mt = mt; addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()), ArrayUtil.fromBoolean(mt.isTransposeResult())); + addTArgument(alpha, beta); } @@ -74,6 +77,30 @@ public class Mmul extends DynamicCustomOp { this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse()); } + public Mmul(INDArray x, + INDArray y, + INDArray z, + double alpha, + double beta, + MMulTranspose mt) { + addInputArgument(x, y); + + if (z != null) + addOutputArgument(z); + + if (mt != null) { + this.mt = mt; + addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), + ArrayUtil.fromBoolean(mt.isTransposeB()), + ArrayUtil.fromBoolean(mt.isTransposeResult())); + } + + this.alpha = alpha; + this.beta = beta; + + addTArgument(alpha, beta); + } + /** * * @param x @@ -84,25 +111,30 @@ public class Mmul extends DynamicCustomOp { INDArray y, INDArray z, MMulTranspose mt) { - super(null, new INDArray[]{x, y}, z == null ? null : new INDArray[]{z}); - if (mt != null) { - this.mt = mt; - addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), - ArrayUtil.fromBoolean(mt.isTransposeB()), - ArrayUtil.fromBoolean(mt.isTransposeResult())); - } + this(x, y, z, 1.0, 0.0, mt); } public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) { + this(x, y, 1.0, 0.0, transposeX, transposeY, transposeZ); + } + + public Mmul(INDArray x, INDArray y, double alpha, double beta, boolean transposeX, boolean transposeY, boolean transposeZ) { addInputArgument(x, y); addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); + addTArgument(alpha, beta); + this.alpha = alpha; + this.beta = beta; + } + + public Mmul(INDArray x, INDArray y, double alpha, double beta) { + this(x,y,null, alpha, beta,null); } public Mmul(INDArray x, INDArray y) { - this(x,y,null,null); + this(x, y, 1.0, 0.0); } public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, @@ -111,6 +143,8 @@ public class Mmul extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); + + addTArgument(alpha, beta); mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 5da64dadb..43181a3b2 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -791,7 +791,7 @@ public class Nd4j { boolean transposeB) { long cRows = (transposeA ? a.columns() : a.rows()); long cCols = (transposeB ? b.rows() : b.columns()); - INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, 'f'); + INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, a.ordering() == 'c' && b.ordering() == 'c' ? 'c' : 'f'); return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0); } @@ -817,12 +817,9 @@ public class Nd4j { boolean transposeB, double alpha, double beta) { - //Note: some views have non-zero offset but 'default' strides (these are OK). And a 'c' order vector such as [10,1] is OK - same buffer as an 'f' order vector with same shape - Preconditions.checkState(c.length() == 1 || c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c) || - c.isVectorOrScalar() && c.elementWiseStride() == 1, - "C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order " + - "and not a view. C (result) array: [%ndSInfo]", c); - getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta); + Preconditions.checkArgument(c.elementWiseStride() == 1, "Nd4j.gemm() C array should NOT be a view"); + + Nd4j.exec(new Mmul(a, b, c, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build())); return c; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java index 423887b64..997bf609c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -40,6 +40,11 @@ public class CheckpointListenerTest extends BaseNd4jTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + public static SameDiff getModel(){ Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -151,7 +156,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { CheckpointListener l = new CheckpointListener.Builder(dir) .keepLast(2) - .saveEvery(1, TimeUnit.SECONDS) + .saveEvery(4, TimeUnit.SECONDS) .build(); sd.setListeners(l); @@ -159,7 +164,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { for(int i=0; i<5; i++ ){ //10 iterations total sd.fit(iter, 1); - Thread.sleep(1000); + Thread.sleep(5000); } //Expect models saved at iterations: 10, 20, 30, 40 diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index da91fb6cf..162e123b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6192,24 +6192,6 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, output); } - @Test - public void testVectorGemv() { - val vectorL = Nd4j.create(new float[]{1, 2, 3}, new long[]{3, 1}); - val vectorN = Nd4j.create(new float[]{1, 2, 3}, new long[]{3}); - val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 
9}, new long[] {3, 3}); - -// log.info("vectorN: {}", vectorN); -// log.info("vectorL: {}", vectorL); - - val outN = matrix.mmul(vectorN); - val outL = matrix.mmul(vectorL); - - assertEquals(outL, outN.reshape(3,1)); - - assertEquals(1, outN.rank()); - } - - @Test public void testMatrixReshape() { val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java index 8263cc07c..8fc247683 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java @@ -60,7 +60,7 @@ public class Level3Test extends BaseNd4jTest { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); - INDArray array3 = array1.mmul(array2); + INDArray array3 = array1.mmul(array2, Nd4j.createUninitialized(new long[]{10, 10}, 'f')); //System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java index 7dcd2285f..b3bd78979 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java @@ -83,7 +83,7 @@ public class DataTypeValidationTests extends BaseNd4jTest { /** * Testing level2 blas */ - @Test(expected = ND4JIllegalStateException.class) + @Test(expected = RuntimeException.class) public void testBlasValidation2() { INDArray a = Nd4j.create(100, 10); INDArray x = Nd4j.create(100); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java index b81b90133..d9f307abe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java @@ -83,22 +83,7 @@ public class BlasTests extends BaseNd4jTest { try { Nd4j.gemm(a, b, view, false, false, 1.0, 0.0); fail("Expected exception"); - } catch (IllegalStateException e) { - assertTrue(e.getMessage().contains("view")); - } - } - - @Test - public void testGemmInvalid2() { - final INDArray a = Nd4j.rand(4, 3); - final INDArray b = Nd4j.rand(4, 5); - - final INDArray target = Nd4j.zeros(3, 5, 'c'); - - try { - Nd4j.gemm(a, b, target, true, false, 1.0, 0.0); - fail("Expected exception"); - } catch (IllegalStateException e) { + } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("view")); } } @@ -114,7 +99,7 @@ public class BlasTests extends BaseNd4jTest { try { Nd4j.gemm(a, b, view, true, false, 1.0, 0.0); fail("Expected exception"); - } catch (IllegalStateException e) { + } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("view")); } }
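
For readers who want to try the new alpha/beta path from the Java API, below is a minimal, hypothetical usage sketch (the class name, shapes, and values are illustrative, not from this PR). It uses only pieces introduced above: the Mmul(x, y, z, alpha, beta, MMulTranspose) constructor, the mmul(INDArray other, char resultOrder) overload, and the convention that alpha and beta are passed as the op's first two T arguments (T_ARG(0)/T_ARG(1)), defaulting to 1.0 and 0.0 when omitted. It mirrors the new DeclarableOpsTests19 C++ tests, where z is pre-assigned 100, x and y are all-ones 10x10 matrices, and alpha = beta = 1 yields 110 in every cell.

import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.factory.Nd4j;

public class MatmulAlphaBetaSketch {            // hypothetical example class, not part of the PR
    public static void main(String[] args) {
        INDArray x = Nd4j.ones(10, 10);
        INDArray y = Nd4j.ones(10, 10);

        // Pre-fill z: with beta != 0 the op accumulates into the existing output,
        // computing z = alpha*(x matmul y) + beta*z
        INDArray z = Nd4j.ones(10, 10).muli(100);

        Nd4j.exec(new Mmul(x, y, z, 1.0, 1.0,
                MMulTranspose.builder().transposeA(false).transposeB(false).build()));
        // Each element of z is now 1*10 + 1*100 = 110, matching the C++ test expectation

        // The new overload makes the result ordering an explicit choice:
        INDArray resultC = x.mmul(y, 'c');
        INDArray resultF = x.mmul(y, 'f');
    }
}

Note that gemm and gemv in BaseLevel2/BaseLevel3 now route through this same Mmul op, so the alpha/beta semantics above also describe what Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta) does after this change.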