MatMul for gemm/gemv calls (#365)

* libnd4j added optional alpha and beta support to matmul

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j typos fixes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j add optional alpha and beta to matmul_bp

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one more typo fix

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added optional alpha and beta to mkl implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* MatMul alpha/beta on java side

Signed-off-by: raver119 <raver119@gmail.com>

* alpha/beta fix in libnd4j

Signed-off-by: raver119 <raver119@gmail.com>

* alpha/beta fix in matmul_bp

Signed-off-by: raver119 <raver119@gmail.com>

* restored view validation

Signed-off-by: raver119 <raver119@gmail.com>

* gemv/gemm now use MatMul op

Signed-off-by: raver119 <raver119@gmail.com>

* few tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* additional INDArray.mmul signature

Signed-off-by: raver119 <raver119@gmail.com>

* make C order default for INDArray.mmul, unless both A/B have F order

Signed-off-by: raver119 <raver119@gmail.com>

* Nd4j.gemm validation fix

Signed-off-by: raver119 <raver119@gmail.com>

* disable mkldnn matmul for xxf with beta != 0 case

Signed-off-by: raver119 <raver119@gmail.com>

* SimpleRnn workspace fix + timeouts

Signed-off-by: Alex Black <blacka101@gmail.com>

* two more tests + minor fix in matmul platform check

Signed-off-by: raver119 <raver119@gmail.com>

* Flaky test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* propagate testresources profile

Signed-off-by: raver119 <raver119@gmail.com>

* Resources fix + flaky test fix

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
master
raver119 2020-04-10 17:57:02 +03:00 committed by GitHub
parent 99c727f15b
commit 3e2dbc65dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 279 additions and 85 deletions

View File

@ -23,6 +23,11 @@ import org.junit.Test;
public class TestDataSets extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L;
}
@Test
public void testTinyImageNetExists() throws Exception {
//Simple sanity check on extracting

View File

@ -44,6 +44,11 @@ import static org.junit.Assert.*;
public class TestCheckpointListener extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Rule
public TemporaryFolder tempDir = new TemporaryFolder();
@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new IrisDataSetIterator(75,150);
DataSetIterator iter = new IrisDataSetIterator(25,50);
return new Pair<>(net, iter);
}
@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest {
CheckpointListener l = new CheckpointListener.Builder(f)
.keepLast(3)
.saveEvery(3, TimeUnit.SECONDS)
.saveEvery(4, TimeUnit.SECONDS)
.build();
net.setListeners(l);
for(int i=0; i<5; i++ ){ //10 iterations total
net.fit(iter);
Thread.sleep(4000);
Thread.sleep(5000);
}
//Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 shoud happen before first 3 seconds is up)

View File

@ -54,6 +54,11 @@ import static org.junit.Assert.*;
@Slf4j
public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -52,6 +52,11 @@ import static org.junit.Assert.*;
public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -69,6 +69,11 @@ import org.nd4j.resources.Resources;
public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType() {
return DataType.FLOAT;
@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType());
assertEquals("Test for dtype: " + dtypeName, outExp, outAct);
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq);
}
}

View File

@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void testCustomLayer() throws Exception {
@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(outExp + " vs " + outAct, eq); }
assertTrue(outExp + " vs " + outAct, eq);
}
}

View File

@ -96,6 +96,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
}
};
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test(expected = IllegalStateException.class)
public void fileNotFoundEndToEnd() throws Exception {
String modelPath = "modelimport/keras/examples/foo/bar.h5";

View File

@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
if(storeLastForTBPTT){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)));
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
}
}
return out;

View File

@ -84,6 +84,13 @@
<profiles>
<profile>
<id>testresources</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
</profile>
<profile>
<id>test-nd4j-native</id>
<activation>

View File

@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
import org.nd4j.resources.Resources;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import javax.imageio.ImageIO;
@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
@Test
public void testMlnMnist_ImageInput() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
@Test
public void testMlnMnist_ImageInput_Async() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
@Test
public void testBinaryIn_BinaryOut() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)

View File

@ -20,6 +20,9 @@
<name>deeplearning4j-remote</name>
<profiles>
<profile>
<id>testresources</id>
</profile>
<profile>
<id>test-nd4j-native</id>
</profile>

View File

@ -60,7 +60,7 @@ namespace sd {
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
#endif
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY);
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0);
};
}

View File

@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
//////////////////////////////////////////////////////////////////////////
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) {
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) {
int xRank = x->rankOf();
int yRank = y->rankOf();
@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
}
mmul(xT, yT, zT, 1., 0.);
mmul(xT, yT, zT, alpha, beta);
}
else { // rest cases - batched mmul
@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
auto xSubArr = (*xT)(i, dimsToExclude);
auto ySubArr = (*yT)(i, dimsToExclude);
auto zSubArr = (*zT)(i, dimsToExclude);
mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.);
mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta);
}
}

View File

@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
const int iSize = (int) block.getIArguments()->size();
int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
// optional use alpha nad beta
iSize = (int)block.getTArguments()->size();
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
double beta = iSize > 1 ? T_ARG(1) : 0.0;
const int xRank = x->rankOf();
const int yRank = y->rankOf();
@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
}
// ******* end of input validation ******* //
MmulHelper::matmul(x, y, z, transX, transY);
MmulHelper::matmul(x, y, z, transX, transY, alpha, beta);
return Status::OK();
}
@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1);
const int iSize = (int) block.getIArguments()->size();
int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
// optional use alpha nad beta
iSize = (int)block.getTArguments()->size();
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
double beta = iSize > 1 ? T_ARG(1) : 0.0;
/*
In: x=[a,b], y=[b,c]
tX tY tZ x y z dz dLdx dLdy
@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a]
sd::ops::matmul op;
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {});
return Status::OK();
}

View File

@ -32,7 +32,7 @@ namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) {
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
// mkl works with following
// [M,K] x [K,N] = [M,N]
@ -150,6 +150,12 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
// Create attributes (to handle alpha and beta if necessary)
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
if (alpha != 1.f) attr.set_output_scales(0, {alpha});
if (beta != 0.f) {
dnnl::post_ops po;
po.append_sum(beta);
attr.set_post_ops(po);
}
// operation primitive description
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
if(x->isEmpty() || y->isEmpty())
return Status::OK();
const int iSize = (int) block.getIArguments()->size();
int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
// optional use alpha nad beta
iSize = (int)block.getTArguments()->size();
float alpha = iSize > 0 ? T_ARG(0) : 1.0;
float beta = iSize > 1 ? T_ARG(1) : 0.0;
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
@ -265,7 +276,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
}
// ******* end of input validation ******* //
matmulMKLDNN(x, y, z, transX, transY);
matmulMKLDNN(x, y, z, transX, transY, alpha, beta);
return Status::OK();
}
@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
const DataType xType = x->dataType();
const DataType yType = y->dataType();
const DataType zType = z->dataType();
float alpha = block.numT() > 0 ? T_ARG(0) : 1.0;
float beta = block.numT() > 1 ? T_ARG(1) : 0.0;
return block.isUseMKLDNN() && x->rankOf() < 3 &&
return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 &&
(
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||

View File

@ -39,6 +39,97 @@ public:
}
};
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
auto x = NDArrayFactory::create<float>('c', {10, 10});
auto y = NDArrayFactory::create<float>('c', {10, 10});
auto e = NDArrayFactory::create<float>('c', {10, 10});
auto z = NDArrayFactory::create<float>('c', {10, 10});
z.assign(100.f);
e.assign(110.f);
x.assign(1.0f);
y.assign(1.0f);
sd::ops::matmul op;
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests19, test_matmul_fcf) {
auto x = NDArrayFactory::create<float>('f', {10, 10});
auto y = NDArrayFactory::create<float>('c', {10, 10});
auto e = NDArrayFactory::create<float>('f', {10, 10});
auto z = NDArrayFactory::create<float>('f', {10, 10});
z.assign(100.f);
e.assign(110.f);
x.assign(1.0f);
y.assign(1.0f);
sd::ops::matmul op;
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests19, test_matmul_cff) {
auto x = NDArrayFactory::create<float>('c', {10, 10});
auto y = NDArrayFactory::create<float>('f', {10, 10});
auto e = NDArrayFactory::create<float>('f', {10, 10});
auto z = NDArrayFactory::create<float>('f', {10, 10});
z.assign(100.f);
e.assign(110.f);
x.assign(1.0f);
y.assign(1.0f);
sd::ops::matmul op;
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests19, test_matmul_ccf) {
auto x = NDArrayFactory::create<float>('c', {10, 10});
auto y = NDArrayFactory::create<float>('c', {10, 10});
auto e = NDArrayFactory::create<float>('f', {10, 10});
auto z = NDArrayFactory::create<float>('f', {10, 10});
z.assign(100.f);
e.assign(110.f);
x.assign(1.0f);
y.assign(1.0f);
sd::ops::matmul op;
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests19, test_matmul_fff) {
auto x = NDArrayFactory::create<float>('f', {10, 10});
auto y = NDArrayFactory::create<float>('f', {10, 10});
auto e = NDArrayFactory::create<float>('f', {10, 10});
auto z = NDArrayFactory::create<float>('f', {10, 10});
z.assign(100.f);
e.assign(110.f);
x.assign(1.0f);
y.assign(1.0f);
sd::ops::matmul op;
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
/*
DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")

View File

@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl;
import lombok.val;
import org.nd4j.linalg.api.blas.Level2;
import org.nd4j.linalg.api.blas.params.GemvParameters;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@ -57,6 +59,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
OpProfiler.getInstance().processBlasCall(false, A, X, Y);
GemvParameters parameters = new GemvParameters(A, X, Y);
Nd4j.exec(new Mmul(A, X, Y, alpha, beta, MMulTranspose.builder().transposeA(false).build()));
/*
if (A.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
parameters.getY());
@ -86,7 +92,7 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
} else {
throw new ND4JIllegalStateException("Unsupported data type " + A.dataType());
}
*/
OpExecutionerUtil.checkForAny(Y);
}

View File

@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.blas.Level3;
import org.nd4j.linalg.api.blas.params.GemmParams;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
@ -59,6 +61,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
GemmParams params = new GemmParams(A, B, C);
Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
/*
int charOder = Order;
if (A.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), params.getC());
@ -73,6 +78,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), 1.0f,
params.getA(), params.getLda(), params.getB(), params.getLdb(), 0, C, params.getLdc());
}
*/
OpExecutionerUtil.checkForAny(C);
}
@ -85,6 +91,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(true, A, B, C);
Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build()));
/*
GemmParams params = new GemmParams(A, B, C, transposeA, transposeB);
if (A.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C);
@ -102,7 +111,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
(float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
C, params.getLdc());
}
*/
OpExecutionerUtil.checkForAny(C);
}

View File

@ -2866,16 +2866,22 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Override
public INDArray mmul(INDArray other) {
public INDArray mmul(INDArray other, char resultOrder) {
Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given");
Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());
// FIXME: for 1D case, we probably want vector output here?
long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()};
INDArray result = createUninitialized(this.dataType(), shape, 'f');
// FIXME: add support for 3D+ here?
long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()};
INDArray result = createUninitialized(this.dataType(), shape, resultOrder);
if (result.isScalar())
return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1);
return mmuli(other, result);
}
@Override
public INDArray mmul(INDArray other) {
return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c');
}
protected INDArray create(int[] shape, char ordering) {
return Nd4j.create(shape, ordering);
}

View File

@ -1232,6 +1232,14 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray mmul(INDArray other);
/**
* Perform a copy matrix multiplication
* @param other other the other matrix to perform matrix multiply with
* @param resultOrder either C or F order for result array
* @return the result of the matrix multiplication
*/
INDArray mmul(INDArray other, char resultOrder);
/**
* Convert this ndarray to a 2d double matrix.
* Note that THIS SHOULD NOT BE USED FOR SPEED.

View File

@ -44,6 +44,8 @@ import java.util.*;
public class Mmul extends DynamicCustomOp {
protected MMulTranspose mt;
protected double alpha = 1.0;
protected double beta = 0.0;
/**
*
@ -59,6 +61,7 @@ public class Mmul extends DynamicCustomOp {
super(null,sameDiff,new SDVariable[]{i_v1,i_v2});
this.mt = mt;
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()), ArrayUtil.fromBoolean(mt.isTransposeResult()));
addTArgument(alpha, beta);
}
@ -74,6 +77,30 @@ public class Mmul extends DynamicCustomOp {
this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse());
}
public Mmul(INDArray x,
INDArray y,
INDArray z,
double alpha,
double beta,
MMulTranspose mt) {
addInputArgument(x, y);
if (z != null)
addOutputArgument(z);
if (mt != null) {
this.mt = mt;
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()),
ArrayUtil.fromBoolean(mt.isTransposeB()),
ArrayUtil.fromBoolean(mt.isTransposeResult()));
}
this.alpha = alpha;
this.beta = beta;
addTArgument(alpha, beta);
}
/**
*
* @param x
@ -84,25 +111,30 @@ public class Mmul extends DynamicCustomOp {
INDArray y,
INDArray z,
MMulTranspose mt) {
super(null, new INDArray[]{x, y}, z == null ? null : new INDArray[]{z});
if (mt != null) {
this.mt = mt;
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()),
ArrayUtil.fromBoolean(mt.isTransposeB()),
ArrayUtil.fromBoolean(mt.isTransposeResult()));
}
this(x, y, z, 1.0, 0.0, mt);
}
public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
this(x, y, 1.0, 0.0, transposeX, transposeY, transposeZ);
}
public Mmul(INDArray x, INDArray y, double alpha, double beta, boolean transposeX, boolean transposeY, boolean transposeZ) {
addInputArgument(x, y);
addIArgument(ArrayUtil.fromBoolean(transposeX),
ArrayUtil.fromBoolean(transposeY),
ArrayUtil.fromBoolean(transposeZ));
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
addTArgument(alpha, beta);
this.alpha = alpha;
this.beta = beta;
}
public Mmul(INDArray x, INDArray y, double alpha, double beta) {
this(x,y,null, alpha, beta,null);
}
public Mmul(INDArray x, INDArray y) {
this(x,y,null,null);
this(x, y, 1.0, 0.0);
}
public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
@ -111,6 +143,8 @@ public class Mmul extends DynamicCustomOp {
addIArgument(ArrayUtil.fromBoolean(transposeX),
ArrayUtil.fromBoolean(transposeY),
ArrayUtil.fromBoolean(transposeZ));
addTArgument(alpha, beta);
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
}

View File

@ -791,7 +791,7 @@ public class Nd4j {
boolean transposeB) {
long cRows = (transposeA ? a.columns() : a.rows());
long cCols = (transposeB ? b.rows() : b.columns());
INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, 'f');
INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, a.ordering() == 'c' && b.ordering() == 'c' ? 'c' : 'f');
return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0);
}
@ -817,12 +817,9 @@ public class Nd4j {
boolean transposeB,
double alpha,
double beta) {
//Note: some views have non-zero offset but 'default' strides (these are OK). And a 'c' order vector such as [10,1] is OK - same buffer as an 'f' order vector with same shape
Preconditions.checkState(c.length() == 1 || c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c) ||
c.isVectorOrScalar() && c.elementWiseStride() == 1,
"C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order " +
"and not a view. C (result) array: [%ndSInfo]", c);
getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta);
Preconditions.checkArgument(c.elementWiseStride() == 1, "Nd4j.gemm() C array should NOT be a view");
Nd4j.exec(new Mmul(a, b, c, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build()));
return c;
}

View File

@ -40,6 +40,11 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
public static SameDiff getModel(){
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -151,7 +156,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
CheckpointListener l = new CheckpointListener.Builder(dir)
.keepLast(2)
.saveEvery(1, TimeUnit.SECONDS)
.saveEvery(4, TimeUnit.SECONDS)
.build();
sd.setListeners(l);
@ -159,7 +164,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
for(int i=0; i<5; i++ ){ //10 iterations total
sd.fit(iter, 1);
Thread.sleep(1000);
Thread.sleep(5000);
}
//Expect models saved at iterations: 10, 20, 30, 40

View File

@ -6192,24 +6192,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, output);
}
@Test
public void testVectorGemv() {
val vectorL = Nd4j.create(new float[]{1, 2, 3}, new long[]{3, 1});
val vectorN = Nd4j.create(new float[]{1, 2, 3}, new long[]{3});
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});
// log.info("vectorN: {}", vectorN);
// log.info("vectorL: {}", vectorL);
val outN = matrix.mmul(vectorN);
val outL = matrix.mmul(vectorL);
assertEquals(outL, outN.reshape(3,1));
assertEquals(1, outN.rank());
}
@Test
public void testMatrixReshape() {
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});

View File

@ -60,7 +60,7 @@ public class Level3Test extends BaseNd4jTest {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
INDArray array3 = array1.mmul(array2);
INDArray array3 = array1.mmul(array2, Nd4j.createUninitialized(new long[]{10, 10}, 'f'));
//System.out.println("Array3: " + Arrays.toString(array3.data().asFloat()));

View File

@ -83,7 +83,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
/**
* Testing level2 blas
*/
@Test(expected = ND4JIllegalStateException.class)
@Test(expected = RuntimeException.class)
public void testBlasValidation2() {
INDArray a = Nd4j.create(100, 10);
INDArray x = Nd4j.create(100);

View File

@ -83,22 +83,7 @@ public class BlasTests extends BaseNd4jTest {
try {
Nd4j.gemm(a, b, view, false, false, 1.0, 0.0);
fail("Expected exception");
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("view"));
}
}
@Test
public void testGemmInvalid2() {
final INDArray a = Nd4j.rand(4, 3);
final INDArray b = Nd4j.rand(4, 5);
final INDArray target = Nd4j.zeros(3, 5, 'c');
try {
Nd4j.gemm(a, b, target, true, false, 1.0, 0.0);
fail("Expected exception");
} catch (IllegalStateException e) {
} catch (IllegalArgumentException e) {
assertTrue(e.getMessage().contains("view"));
}
}
@ -114,7 +99,7 @@ public class BlasTests extends BaseNd4jTest {
try {
Nd4j.gemm(a, b, view, true, false, 1.0, 0.0);
fail("Expected exception");
} catch (IllegalStateException e) {
} catch (IllegalArgumentException e) {
assertTrue(e.getMessage().contains("view"));
}
}