MatMul for gemm/gemv calls (#365)
* libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
99c727f15b
commit
3e2dbc65dd
|
@ -23,6 +23,11 @@ import org.junit.Test;
|
|||
|
||||
public class TestDataSets extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTinyImageNetExists() throws Exception {
|
||||
//Simple sanity check on extracting
|
||||
|
|
|
@ -44,6 +44,11 @@ import static org.junit.Assert.*;
|
|||
|
||||
public class TestCheckpointListener extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder tempDir = new TemporaryFolder();
|
||||
|
||||
|
@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(75,150);
|
||||
DataSetIterator iter = new IrisDataSetIterator(25,50);
|
||||
|
||||
return new Pair<>(net, iter);
|
||||
}
|
||||
|
@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest {
|
|||
|
||||
CheckpointListener l = new CheckpointListener.Builder(f)
|
||||
.keepLast(3)
|
||||
.saveEvery(3, TimeUnit.SECONDS)
|
||||
.saveEvery(4, TimeUnit.SECONDS)
|
||||
.build();
|
||||
net.setListeners(l);
|
||||
|
||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
||||
net.fit(iter);
|
||||
Thread.sleep(4000);
|
||||
Thread.sleep(5000);
|
||||
}
|
||||
|
||||
//Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 shoud happen before first 3 seconds is up)
|
||||
|
|
|
@ -54,6 +54,11 @@ import static org.junit.Assert.*;
|
|||
@Slf4j
|
||||
public class RegressionTest100a extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
|
|
|
@ -52,6 +52,11 @@ import static org.junit.Assert.*;
|
|||
|
||||
public class RegressionTest100b3 extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
|
|
|
@ -69,6 +69,11 @@ import org.nd4j.resources.Resources;
|
|||
|
||||
public class RegressionTest100b4 extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return DataType.FLOAT;
|
||||
|
@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
|||
|
||||
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
||||
assertEquals(dtype, net.params().dataType());
|
||||
assertEquals("Test for dtype: " + dtypeName, outExp, outAct);
|
||||
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
||||
assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
|||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCustomLayer() throws Exception {
|
||||
|
||||
|
@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
|||
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
||||
assertEquals(dtype, net.params().dataType());
|
||||
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
||||
assertTrue(outExp + " vs " + outAct, eq); }
|
||||
assertTrue(outExp + " vs " + outAct, eq);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -96,6 +96,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
}
|
||||
};
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void fileNotFoundEndToEnd() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
||||
|
|
|
@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
|
|||
INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
|
||||
if(storeLastForTBPTT){
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)));
|
||||
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
|
||||
}
|
||||
}
|
||||
return out;
|
||||
|
|
|
@ -84,6 +84,13 @@
|
|||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>testresources</id>
|
||||
<activation>
|
||||
<activeByDefault>true</activeByDefault>
|
||||
</activation>
|
||||
</profile>
|
||||
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<activation>
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer;
|
|||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
|
||||
import org.nd4j.resources.Resources;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
|
@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testMlnMnist_ImageInput() throws Exception {
|
||||
|
||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
|
||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||
|
||||
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||
|
@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testMlnMnist_ImageInput_Async() throws Exception {
|
||||
|
||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
|
||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||
|
||||
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||
|
@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testBinaryIn_BinaryOut() throws Exception {
|
||||
|
||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
|
||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||
|
||||
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
<name>deeplearning4j-remote</name>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>testresources</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
|
|
|
@ -60,7 +60,7 @@ namespace sd {
|
|||
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
|
||||
#endif
|
||||
|
||||
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY);
|
||||
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) {
|
||||
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) {
|
||||
int xRank = x->rankOf();
|
||||
int yRank = y->rankOf();
|
||||
|
||||
|
@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
|||
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
|
||||
}
|
||||
|
||||
mmul(xT, yT, zT, 1., 0.);
|
||||
mmul(xT, yT, zT, alpha, beta);
|
||||
}
|
||||
else { // rest cases - batched mmul
|
||||
|
||||
|
@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
|||
auto xSubArr = (*xT)(i, dimsToExclude);
|
||||
auto ySubArr = (*yT)(i, dimsToExclude);
|
||||
auto zSubArr = (*zT)(i, dimsToExclude);
|
||||
mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.);
|
||||
mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
|||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
// optional use alpha nad beta
|
||||
iSize = (int)block.getTArguments()->size();
|
||||
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||
double beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
|
@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
|||
}
|
||||
// ******* end of input validation ******* //
|
||||
|
||||
MmulHelper::matmul(x, y, z, transX, transY);
|
||||
MmulHelper::matmul(x, y, z, transX, transY, alpha, beta);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
|||
auto dldx = OUTPUT_VARIABLE(0);
|
||||
auto dldy = OUTPUT_VARIABLE(1);
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
// optional use alpha nad beta
|
||||
iSize = (int)block.getTArguments()->size();
|
||||
|
||||
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||
double beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||
|
||||
/*
|
||||
In: x=[a,b], y=[b,c]
|
||||
tX tY tZ x y z dz dLdx dLdy
|
||||
|
@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a]
|
|||
|
||||
|
||||
sd::ops::matmul op;
|
||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
||||
op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {});
|
||||
op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
|||
namespace platforms {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) {
|
||||
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
|
||||
|
||||
// mkl works with following
|
||||
// [M,K] x [K,N] = [M,N]
|
||||
|
@ -150,6 +150,12 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
|||
|
||||
// Create attributes (to handle alpha and beta if necessary)
|
||||
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
||||
if (alpha != 1.f) attr.set_output_scales(0, {alpha});
|
||||
if (beta != 0.f) {
|
||||
dnnl::post_ops po;
|
||||
po.append_sum(beta);
|
||||
attr.set_post_ops(po);
|
||||
}
|
||||
|
||||
// operation primitive description
|
||||
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
|
||||
|
@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
|
|||
if(x->isEmpty() || y->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
// optional use alpha nad beta
|
||||
iSize = (int)block.getTArguments()->size();
|
||||
float alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||
float beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
|
@ -265,7 +276,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
|
|||
}
|
||||
// ******* end of input validation ******* //
|
||||
|
||||
matmulMKLDNN(x, y, z, transX, transY);
|
||||
matmulMKLDNN(x, y, z, transX, transY, alpha, beta);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) {
|
|||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
|
||||
auto z = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
const DataType xType = x->dataType();
|
||||
const DataType yType = y->dataType();
|
||||
const DataType zType = z->dataType();
|
||||
|
||||
float alpha = block.numT() > 0 ? T_ARG(0) : 1.0;
|
||||
float beta = block.numT() > 1 ? T_ARG(1) : 0.0;
|
||||
|
||||
return block.isUseMKLDNN() && x->rankOf() < 3 &&
|
||||
return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 &&
|
||||
(
|
||||
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
|
||||
|
|
|
@ -39,6 +39,97 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
|
||||
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto e = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto z = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
||||
z.assign(100.f);
|
||||
e.assign(110.f);
|
||||
x.assign(1.0f);
|
||||
y.assign(1.0f);
|
||||
|
||||
sd::ops::matmul op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_matmul_fcf) {
|
||||
auto x = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||
|
||||
z.assign(100.f);
|
||||
e.assign(110.f);
|
||||
x.assign(1.0f);
|
||||
y.assign(1.0f);
|
||||
|
||||
sd::ops::matmul op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_matmul_cff) {
|
||||
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto y = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||
|
||||
z.assign(100.f);
|
||||
e.assign(110.f);
|
||||
x.assign(1.0f);
|
||||
y.assign(1.0f);
|
||||
|
||||
sd::ops::matmul op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_matmul_ccf) {
|
||||
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||
|
||||
z.assign(100.f);
|
||||
e.assign(110.f);
|
||||
x.assign(1.0f);
|
||||
y.assign(1.0f);
|
||||
|
||||
sd::ops::matmul op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_matmul_fff) {
|
||||
auto x = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto y = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||
|
||||
z.assign(100.f);
|
||||
e.assign(110.f);
|
||||
x.assign(1.0f);
|
||||
y.assign(1.0f);
|
||||
|
||||
sd::ops::matmul op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
||||
/*
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
|
||||
|
|
|
@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl;
|
|||
import lombok.val;
|
||||
import org.nd4j.linalg.api.blas.Level2;
|
||||
import org.nd4j.linalg.api.blas.params.GemvParameters;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -57,6 +59,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
|
|||
OpProfiler.getInstance().processBlasCall(false, A, X, Y);
|
||||
|
||||
GemvParameters parameters = new GemvParameters(A, X, Y);
|
||||
|
||||
Nd4j.exec(new Mmul(A, X, Y, alpha, beta, MMulTranspose.builder().transposeA(false).build()));
|
||||
|
||||
/*
|
||||
if (A.data().dataType() == DataType.DOUBLE) {
|
||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
|
||||
parameters.getY());
|
||||
|
@ -86,7 +92,7 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
|
|||
} else {
|
||||
throw new ND4JIllegalStateException("Unsupported data type " + A.dataType());
|
||||
}
|
||||
|
||||
*/
|
||||
OpExecutionerUtil.checkForAny(Y);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.linalg.api.blas.Level3;
|
||||
import org.nd4j.linalg.api.blas.params.GemmParams;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.profiler.OpProfiler;
|
||||
|
@ -59,6 +61,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
|||
|
||||
GemmParams params = new GemmParams(A, B, C);
|
||||
|
||||
Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
|
||||
|
||||
/*
|
||||
int charOder = Order;
|
||||
if (A.data().dataType() == DataType.DOUBLE) {
|
||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), params.getC());
|
||||
|
@ -73,6 +78,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
|||
hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), 1.0f,
|
||||
params.getA(), params.getLda(), params.getB(), params.getLdb(), 0, C, params.getLdc());
|
||||
}
|
||||
*/
|
||||
|
||||
OpExecutionerUtil.checkForAny(C);
|
||||
}
|
||||
|
@ -85,6 +91,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
|||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||
OpProfiler.getInstance().processBlasCall(true, A, B, C);
|
||||
|
||||
Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build()));
|
||||
|
||||
/*
|
||||
GemmParams params = new GemmParams(A, B, C, transposeA, transposeB);
|
||||
if (A.data().dataType() == DataType.DOUBLE) {
|
||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C);
|
||||
|
@ -102,7 +111,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
|||
(float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
|
||||
C, params.getLdc());
|
||||
}
|
||||
|
||||
*/
|
||||
OpExecutionerUtil.checkForAny(C);
|
||||
}
|
||||
|
||||
|
|
|
@ -2866,16 +2866,22 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public INDArray mmul(INDArray other) {
|
||||
public INDArray mmul(INDArray other, char resultOrder) {
|
||||
Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given");
|
||||
Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());
|
||||
// FIXME: for 1D case, we probably want vector output here?
|
||||
long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()};
|
||||
INDArray result = createUninitialized(this.dataType(), shape, 'f');
|
||||
// FIXME: add support for 3D+ here?
|
||||
long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()};
|
||||
INDArray result = createUninitialized(this.dataType(), shape, resultOrder);
|
||||
if (result.isScalar())
|
||||
return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1);
|
||||
return mmuli(other, result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray mmul(INDArray other) {
|
||||
return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c');
|
||||
}
|
||||
|
||||
protected INDArray create(int[] shape, char ordering) {
|
||||
return Nd4j.create(shape, ordering);
|
||||
}
|
||||
|
|
|
@ -1232,6 +1232,14 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray mmul(INDArray other);
|
||||
|
||||
/**
|
||||
* Perform a copy matrix multiplication
|
||||
* @param other other the other matrix to perform matrix multiply with
|
||||
* @param resultOrder either C or F order for result array
|
||||
* @return the result of the matrix multiplication
|
||||
*/
|
||||
INDArray mmul(INDArray other, char resultOrder);
|
||||
|
||||
/**
|
||||
* Convert this ndarray to a 2d double matrix.
|
||||
* Note that THIS SHOULD NOT BE USED FOR SPEED.
|
||||
|
|
|
@ -44,6 +44,8 @@ import java.util.*;
|
|||
public class Mmul extends DynamicCustomOp {
|
||||
|
||||
protected MMulTranspose mt;
|
||||
protected double alpha = 1.0;
|
||||
protected double beta = 0.0;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -59,6 +61,7 @@ public class Mmul extends DynamicCustomOp {
|
|||
super(null,sameDiff,new SDVariable[]{i_v1,i_v2});
|
||||
this.mt = mt;
|
||||
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()), ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
||||
addTArgument(alpha, beta);
|
||||
}
|
||||
|
||||
|
||||
|
@ -74,6 +77,30 @@ public class Mmul extends DynamicCustomOp {
|
|||
this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse());
|
||||
}
|
||||
|
||||
public Mmul(INDArray x,
|
||||
INDArray y,
|
||||
INDArray z,
|
||||
double alpha,
|
||||
double beta,
|
||||
MMulTranspose mt) {
|
||||
addInputArgument(x, y);
|
||||
|
||||
if (z != null)
|
||||
addOutputArgument(z);
|
||||
|
||||
if (mt != null) {
|
||||
this.mt = mt;
|
||||
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()),
|
||||
ArrayUtil.fromBoolean(mt.isTransposeB()),
|
||||
ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
||||
}
|
||||
|
||||
this.alpha = alpha;
|
||||
this.beta = beta;
|
||||
|
||||
addTArgument(alpha, beta);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param x
|
||||
|
@ -84,25 +111,30 @@ public class Mmul extends DynamicCustomOp {
|
|||
INDArray y,
|
||||
INDArray z,
|
||||
MMulTranspose mt) {
|
||||
super(null, new INDArray[]{x, y}, z == null ? null : new INDArray[]{z});
|
||||
if (mt != null) {
|
||||
this.mt = mt;
|
||||
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()),
|
||||
ArrayUtil.fromBoolean(mt.isTransposeB()),
|
||||
ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
||||
}
|
||||
this(x, y, z, 1.0, 0.0, mt);
|
||||
}
|
||||
|
||||
public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
this(x, y, 1.0, 0.0, transposeX, transposeY, transposeZ);
|
||||
}
|
||||
|
||||
public Mmul(INDArray x, INDArray y, double alpha, double beta, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
addInputArgument(x, y);
|
||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||
ArrayUtil.fromBoolean(transposeY),
|
||||
ArrayUtil.fromBoolean(transposeZ));
|
||||
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
||||
addTArgument(alpha, beta);
|
||||
this.alpha = alpha;
|
||||
this.beta = beta;
|
||||
}
|
||||
|
||||
public Mmul(INDArray x, INDArray y, double alpha, double beta) {
|
||||
this(x,y,null, alpha, beta,null);
|
||||
}
|
||||
|
||||
public Mmul(INDArray x, INDArray y) {
|
||||
this(x,y,null,null);
|
||||
this(x, y, 1.0, 0.0);
|
||||
}
|
||||
|
||||
public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
|
||||
|
@ -111,6 +143,8 @@ public class Mmul extends DynamicCustomOp {
|
|||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||
ArrayUtil.fromBoolean(transposeY),
|
||||
ArrayUtil.fromBoolean(transposeZ));
|
||||
|
||||
addTArgument(alpha, beta);
|
||||
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
||||
}
|
||||
|
||||
|
|
|
@ -791,7 +791,7 @@ public class Nd4j {
|
|||
boolean transposeB) {
|
||||
long cRows = (transposeA ? a.columns() : a.rows());
|
||||
long cCols = (transposeB ? b.rows() : b.columns());
|
||||
INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, 'f');
|
||||
INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, a.ordering() == 'c' && b.ordering() == 'c' ? 'c' : 'f');
|
||||
return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0);
|
||||
}
|
||||
|
||||
|
@ -817,12 +817,9 @@ public class Nd4j {
|
|||
boolean transposeB,
|
||||
double alpha,
|
||||
double beta) {
|
||||
//Note: some views have non-zero offset but 'default' strides (these are OK). And a 'c' order vector such as [10,1] is OK - same buffer as an 'f' order vector with same shape
|
||||
Preconditions.checkState(c.length() == 1 || c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c) ||
|
||||
c.isVectorOrScalar() && c.elementWiseStride() == 1,
|
||||
"C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order " +
|
||||
"and not a view. C (result) array: [%ndSInfo]", c);
|
||||
getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta);
|
||||
Preconditions.checkArgument(c.elementWiseStride() == 1, "Nd4j.gemm() C array should NOT be a view");
|
||||
|
||||
Nd4j.exec(new Mmul(a, b, c, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build()));
|
||||
return c;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,6 +40,11 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
public static SameDiff getModel(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -151,7 +156,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
|
||||
CheckpointListener l = new CheckpointListener.Builder(dir)
|
||||
.keepLast(2)
|
||||
.saveEvery(1, TimeUnit.SECONDS)
|
||||
.saveEvery(4, TimeUnit.SECONDS)
|
||||
.build();
|
||||
sd.setListeners(l);
|
||||
|
||||
|
@ -159,7 +164,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
|
||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
||||
sd.fit(iter, 1);
|
||||
Thread.sleep(1000);
|
||||
Thread.sleep(5000);
|
||||
}
|
||||
|
||||
//Expect models saved at iterations: 10, 20, 30, 40
|
||||
|
|
|
@ -6192,24 +6192,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(exp, output);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorGemv() {
|
||||
val vectorL = Nd4j.create(new float[]{1, 2, 3}, new long[]{3, 1});
|
||||
val vectorN = Nd4j.create(new float[]{1, 2, 3}, new long[]{3});
|
||||
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});
|
||||
|
||||
// log.info("vectorN: {}", vectorN);
|
||||
// log.info("vectorL: {}", vectorL);
|
||||
|
||||
val outN = matrix.mmul(vectorN);
|
||||
val outL = matrix.mmul(vectorL);
|
||||
|
||||
assertEquals(outL, outN.reshape(3,1));
|
||||
|
||||
assertEquals(1, outN.rank());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testMatrixReshape() {
|
||||
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});
|
||||
|
|
|
@ -60,7 +60,7 @@ public class Level3Test extends BaseNd4jTest {
|
|||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
||||
|
||||
INDArray array3 = array1.mmul(array2);
|
||||
INDArray array3 = array1.mmul(array2, Nd4j.createUninitialized(new long[]{10, 10}, 'f'));
|
||||
|
||||
|
||||
//System.out.println("Array3: " + Arrays.toString(array3.data().asFloat()));
|
||||
|
|
|
@ -83,7 +83,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
|
|||
/**
|
||||
* Testing level2 blas
|
||||
*/
|
||||
@Test(expected = ND4JIllegalStateException.class)
|
||||
@Test(expected = RuntimeException.class)
|
||||
public void testBlasValidation2() {
|
||||
INDArray a = Nd4j.create(100, 10);
|
||||
INDArray x = Nd4j.create(100);
|
||||
|
|
|
@ -83,22 +83,7 @@ public class BlasTests extends BaseNd4jTest {
|
|||
try {
|
||||
Nd4j.gemm(a, b, view, false, false, 1.0, 0.0);
|
||||
fail("Expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
assertTrue(e.getMessage().contains("view"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGemmInvalid2() {
|
||||
final INDArray a = Nd4j.rand(4, 3);
|
||||
final INDArray b = Nd4j.rand(4, 5);
|
||||
|
||||
final INDArray target = Nd4j.zeros(3, 5, 'c');
|
||||
|
||||
try {
|
||||
Nd4j.gemm(a, b, target, true, false, 1.0, 0.0);
|
||||
fail("Expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
} catch (IllegalArgumentException e) {
|
||||
assertTrue(e.getMessage().contains("view"));
|
||||
}
|
||||
}
|
||||
|
@ -114,7 +99,7 @@ public class BlasTests extends BaseNd4jTest {
|
|||
try {
|
||||
Nd4j.gemm(a, b, view, true, false, 1.0, 0.0);
|
||||
fail("Expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
} catch (IllegalArgumentException e) {
|
||||
assertTrue(e.getMessage().contains("view"));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue