MatMul for gemm/gemv calls (#365)
* libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
99c727f15b
commit
3e2dbc65dd
|
@ -23,6 +23,11 @@ import org.junit.Test;
|
||||||
|
|
||||||
public class TestDataSets extends BaseDL4JTest {
|
public class TestDataSets extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTinyImageNetExists() throws Exception {
|
public void testTinyImageNetExists() throws Exception {
|
||||||
//Simple sanity check on extracting
|
//Simple sanity check on extracting
|
||||||
|
|
|
@ -44,6 +44,11 @@ import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class TestCheckpointListener extends BaseDL4JTest {
|
public class TestCheckpointListener extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder tempDir = new TemporaryFolder();
|
public TemporaryFolder tempDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(75,150);
|
DataSetIterator iter = new IrisDataSetIterator(25,50);
|
||||||
|
|
||||||
return new Pair<>(net, iter);
|
return new Pair<>(net, iter);
|
||||||
}
|
}
|
||||||
|
@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest {
|
||||||
|
|
||||||
CheckpointListener l = new CheckpointListener.Builder(f)
|
CheckpointListener l = new CheckpointListener.Builder(f)
|
||||||
.keepLast(3)
|
.keepLast(3)
|
||||||
.saveEvery(3, TimeUnit.SECONDS)
|
.saveEvery(4, TimeUnit.SECONDS)
|
||||||
.build();
|
.build();
|
||||||
net.setListeners(l);
|
net.setListeners(l);
|
||||||
|
|
||||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
for(int i=0; i<5; i++ ){ //10 iterations total
|
||||||
net.fit(iter);
|
net.fit(iter);
|
||||||
Thread.sleep(4000);
|
Thread.sleep(5000);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 shoud happen before first 3 seconds is up)
|
//Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 shoud happen before first 3 seconds is up)
|
||||||
|
|
|
@ -54,6 +54,11 @@ import static org.junit.Assert.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class RegressionTest100a extends BaseDL4JTest {
|
public class RegressionTest100a extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -52,6 +52,11 @@ import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class RegressionTest100b3 extends BaseDL4JTest {
|
public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -69,6 +69,11 @@ import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
public class RegressionTest100b4 extends BaseDL4JTest {
|
public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType() {
|
public DataType getDataType() {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
|
|
||||||
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
||||||
assertEquals(dtype, net.params().dataType());
|
assertEquals(dtype, net.params().dataType());
|
||||||
assertEquals("Test for dtype: " + dtypeName, outExp, outAct);
|
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
||||||
|
assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCustomLayer() throws Exception {
|
public void testCustomLayer() throws Exception {
|
||||||
|
|
||||||
|
@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
||||||
assertEquals(dtype, net.params().dataType());
|
assertEquals(dtype, net.params().dataType());
|
||||||
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
||||||
assertTrue(outExp + " vs " + outAct, eq); }
|
assertTrue(outExp + " vs " + outAct, eq);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -96,6 +96,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test(expected = IllegalStateException.class)
|
||||||
public void fileNotFoundEndToEnd() throws Exception {
|
public void fileNotFoundEndToEnd() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
||||||
|
|
|
@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
|
||||||
INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
|
INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
|
||||||
if(storeLastForTBPTT){
|
if(storeLastForTBPTT){
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||||
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)));
|
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
|
|
|
@ -84,6 +84,13 @@
|
||||||
|
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<id>testresources</id>
|
||||||
|
<activation>
|
||||||
|
<activeByDefault>true</activeByDefault>
|
||||||
|
</activation>
|
||||||
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
<activation>
|
<activation>
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||||
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
|
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
|
||||||
|
import org.nd4j.resources.Resources;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
import javax.imageio.ImageIO;
|
import javax.imageio.ImageIO;
|
||||||
|
@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testMlnMnist_ImageInput() throws Exception {
|
public void testMlnMnist_ImageInput() throws Exception {
|
||||||
|
|
||||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
|
||||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
|
||||||
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||||
|
@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testMlnMnist_ImageInput_Async() throws Exception {
|
public void testMlnMnist_ImageInput_Async() throws Exception {
|
||||||
|
|
||||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
|
||||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
|
||||||
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||||
|
@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testBinaryIn_BinaryOut() throws Exception {
|
public void testBinaryIn_BinaryOut() throws Exception {
|
||||||
|
|
||||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
|
||||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
|
||||||
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
|
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
<name>deeplearning4j-remote</name>
|
<name>deeplearning4j-remote</name>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<id>testresources</id>
|
||||||
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
|
|
|
@ -60,7 +60,7 @@ namespace sd {
|
||||||
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
|
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY);
|
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) {
|
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) {
|
||||||
int xRank = x->rankOf();
|
int xRank = x->rankOf();
|
||||||
int yRank = y->rankOf();
|
int yRank = y->rankOf();
|
||||||
|
|
||||||
|
@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
||||||
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
|
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
mmul(xT, yT, zT, 1., 0.);
|
mmul(xT, yT, zT, alpha, beta);
|
||||||
}
|
}
|
||||||
else { // rest cases - batched mmul
|
else { // rest cases - batched mmul
|
||||||
|
|
||||||
|
@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
||||||
auto xSubArr = (*xT)(i, dimsToExclude);
|
auto xSubArr = (*xT)(i, dimsToExclude);
|
||||||
auto ySubArr = (*yT)(i, dimsToExclude);
|
auto ySubArr = (*yT)(i, dimsToExclude);
|
||||||
auto zSubArr = (*zT)(i, dimsToExclude);
|
auto zSubArr = (*zT)(i, dimsToExclude);
|
||||||
mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.);
|
mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
int iSize = (int) block.getIArguments()->size();
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
// optional use alpha nad beta
|
||||||
|
iSize = (int)block.getTArguments()->size();
|
||||||
|
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||||
|
double beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||||
|
|
||||||
const int xRank = x->rankOf();
|
const int xRank = x->rankOf();
|
||||||
const int yRank = y->rankOf();
|
const int yRank = y->rankOf();
|
||||||
|
@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||||
}
|
}
|
||||||
// ******* end of input validation ******* //
|
// ******* end of input validation ******* //
|
||||||
|
|
||||||
MmulHelper::matmul(x, y, z, transX, transY);
|
MmulHelper::matmul(x, y, z, transX, transY, alpha, beta);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||||
auto dldx = OUTPUT_VARIABLE(0);
|
auto dldx = OUTPUT_VARIABLE(0);
|
||||||
auto dldy = OUTPUT_VARIABLE(1);
|
auto dldy = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
int iSize = (int) block.getIArguments()->size();
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
|
||||||
|
// optional use alpha nad beta
|
||||||
|
iSize = (int)block.getTArguments()->size();
|
||||||
|
|
||||||
|
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||||
|
double beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
In: x=[a,b], y=[b,c]
|
In: x=[a,b], y=[b,c]
|
||||||
tX tY tZ x y z dz dLdx dLdy
|
tX tY tZ x y z dz dLdx dLdy
|
||||||
|
@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a]
|
||||||
|
|
||||||
|
|
||||||
sd::ops::matmul op;
|
sd::ops::matmul op;
|
||||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {});
|
||||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) {
|
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
|
||||||
|
|
||||||
// mkl works with following
|
// mkl works with following
|
||||||
// [M,K] x [K,N] = [M,N]
|
// [M,K] x [K,N] = [M,N]
|
||||||
|
@ -150,6 +150,12 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
||||||
|
|
||||||
// Create attributes (to handle alpha and beta if necessary)
|
// Create attributes (to handle alpha and beta if necessary)
|
||||||
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
||||||
|
if (alpha != 1.f) attr.set_output_scales(0, {alpha});
|
||||||
|
if (beta != 0.f) {
|
||||||
|
dnnl::post_ops po;
|
||||||
|
po.append_sum(beta);
|
||||||
|
attr.set_post_ops(po);
|
||||||
|
}
|
||||||
|
|
||||||
// operation primitive description
|
// operation primitive description
|
||||||
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
|
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
|
||||||
|
@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
|
||||||
if(x->isEmpty() || y->isEmpty())
|
if(x->isEmpty() || y->isEmpty())
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
int iSize = (int) block.getIArguments()->size();
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
|
||||||
|
// optional use alpha nad beta
|
||||||
|
iSize = (int)block.getTArguments()->size();
|
||||||
|
float alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||||
|
float beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||||
|
|
||||||
const int xRank = x->rankOf();
|
const int xRank = x->rankOf();
|
||||||
const int yRank = y->rankOf();
|
const int yRank = y->rankOf();
|
||||||
const int zRank = z->rankOf();
|
const int zRank = z->rankOf();
|
||||||
|
@ -265,7 +276,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
|
||||||
}
|
}
|
||||||
// ******* end of input validation ******* //
|
// ******* end of input validation ******* //
|
||||||
|
|
||||||
matmulMKLDNN(x, y, z, transX, transY);
|
matmulMKLDNN(x, y, z, transX, transY, alpha, beta);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto z = INPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const DataType xType = x->dataType();
|
const DataType xType = x->dataType();
|
||||||
const DataType yType = y->dataType();
|
const DataType yType = y->dataType();
|
||||||
const DataType zType = z->dataType();
|
const DataType zType = z->dataType();
|
||||||
|
|
||||||
|
float alpha = block.numT() > 0 ? T_ARG(0) : 1.0;
|
||||||
|
float beta = block.numT() > 1 ? T_ARG(1) : 0.0;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && x->rankOf() < 3 &&
|
return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 &&
|
||||||
(
|
(
|
||||||
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
|
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
|
||||||
|
|
|
@ -39,6 +39,97 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_fcf) {
|
||||||
|
auto x = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_cff) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_ccf) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_fff) {
|
||||||
|
auto x = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
||||||
/*
|
/*
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
|
DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
|
||||||
|
|
|
@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.linalg.api.blas.Level2;
|
import org.nd4j.linalg.api.blas.Level2;
|
||||||
import org.nd4j.linalg.api.blas.params.GemvParameters;
|
import org.nd4j.linalg.api.blas.params.GemvParameters;
|
||||||
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -57,6 +59,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
|
||||||
OpProfiler.getInstance().processBlasCall(false, A, X, Y);
|
OpProfiler.getInstance().processBlasCall(false, A, X, Y);
|
||||||
|
|
||||||
GemvParameters parameters = new GemvParameters(A, X, Y);
|
GemvParameters parameters = new GemvParameters(A, X, Y);
|
||||||
|
|
||||||
|
Nd4j.exec(new Mmul(A, X, Y, alpha, beta, MMulTranspose.builder().transposeA(false).build()));
|
||||||
|
|
||||||
|
/*
|
||||||
if (A.data().dataType() == DataType.DOUBLE) {
|
if (A.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
|
||||||
parameters.getY());
|
parameters.getY());
|
||||||
|
@ -86,7 +92,7 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
|
||||||
} else {
|
} else {
|
||||||
throw new ND4JIllegalStateException("Unsupported data type " + A.dataType());
|
throw new ND4JIllegalStateException("Unsupported data type " + A.dataType());
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
OpExecutionerUtil.checkForAny(Y);
|
OpExecutionerUtil.checkForAny(Y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.linalg.api.blas.Level3;
|
import org.nd4j.linalg.api.blas.Level3;
|
||||||
import org.nd4j.linalg.api.blas.params.GemmParams;
|
import org.nd4j.linalg.api.blas.params.GemmParams;
|
||||||
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.profiler.OpProfiler;
|
import org.nd4j.linalg.profiler.OpProfiler;
|
||||||
|
@ -59,6 +61,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
||||||
|
|
||||||
GemmParams params = new GemmParams(A, B, C);
|
GemmParams params = new GemmParams(A, B, C);
|
||||||
|
|
||||||
|
Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
|
||||||
|
|
||||||
|
/*
|
||||||
int charOder = Order;
|
int charOder = Order;
|
||||||
if (A.data().dataType() == DataType.DOUBLE) {
|
if (A.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), params.getC());
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), params.getC());
|
||||||
|
@ -73,6 +78,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
||||||
hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), 1.0f,
|
hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), 1.0f,
|
||||||
params.getA(), params.getLda(), params.getB(), params.getLdb(), 0, C, params.getLdc());
|
params.getA(), params.getLda(), params.getB(), params.getLdb(), 0, C, params.getLdc());
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
OpExecutionerUtil.checkForAny(C);
|
OpExecutionerUtil.checkForAny(C);
|
||||||
}
|
}
|
||||||
|
@ -85,6 +91,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(true, A, B, C);
|
OpProfiler.getInstance().processBlasCall(true, A, B, C);
|
||||||
|
|
||||||
|
Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build()));
|
||||||
|
|
||||||
|
/*
|
||||||
GemmParams params = new GemmParams(A, B, C, transposeA, transposeB);
|
GemmParams params = new GemmParams(A, B, C, transposeA, transposeB);
|
||||||
if (A.data().dataType() == DataType.DOUBLE) {
|
if (A.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C);
|
||||||
|
@ -102,7 +111,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 {
|
||||||
(float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
|
(float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
|
||||||
C, params.getLdc());
|
C, params.getLdc());
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
OpExecutionerUtil.checkForAny(C);
|
OpExecutionerUtil.checkForAny(C);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2866,16 +2866,22 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray mmul(INDArray other) {
|
public INDArray mmul(INDArray other, char resultOrder) {
|
||||||
|
Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given");
|
||||||
Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());
|
Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());
|
||||||
// FIXME: for 1D case, we probably want vector output here?
|
// FIXME: add support for 3D+ here?
|
||||||
long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()};
|
long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()};
|
||||||
INDArray result = createUninitialized(this.dataType(), shape, 'f');
|
INDArray result = createUninitialized(this.dataType(), shape, resultOrder);
|
||||||
if (result.isScalar())
|
if (result.isScalar())
|
||||||
return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1);
|
return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1);
|
||||||
return mmuli(other, result);
|
return mmuli(other, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray mmul(INDArray other) {
|
||||||
|
return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c');
|
||||||
|
}
|
||||||
|
|
||||||
protected INDArray create(int[] shape, char ordering) {
|
protected INDArray create(int[] shape, char ordering) {
|
||||||
return Nd4j.create(shape, ordering);
|
return Nd4j.create(shape, ordering);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1232,6 +1232,14 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
*/
|
*/
|
||||||
INDArray mmul(INDArray other);
|
INDArray mmul(INDArray other);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform a copy matrix multiplication
|
||||||
|
* @param other other the other matrix to perform matrix multiply with
|
||||||
|
* @param resultOrder either C or F order for result array
|
||||||
|
* @return the result of the matrix multiplication
|
||||||
|
*/
|
||||||
|
INDArray mmul(INDArray other, char resultOrder);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this ndarray to a 2d double matrix.
|
* Convert this ndarray to a 2d double matrix.
|
||||||
* Note that THIS SHOULD NOT BE USED FOR SPEED.
|
* Note that THIS SHOULD NOT BE USED FOR SPEED.
|
||||||
|
|
|
@ -44,6 +44,8 @@ import java.util.*;
|
||||||
public class Mmul extends DynamicCustomOp {
|
public class Mmul extends DynamicCustomOp {
|
||||||
|
|
||||||
protected MMulTranspose mt;
|
protected MMulTranspose mt;
|
||||||
|
protected double alpha = 1.0;
|
||||||
|
protected double beta = 0.0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -59,6 +61,7 @@ public class Mmul extends DynamicCustomOp {
|
||||||
super(null,sameDiff,new SDVariable[]{i_v1,i_v2});
|
super(null,sameDiff,new SDVariable[]{i_v1,i_v2});
|
||||||
this.mt = mt;
|
this.mt = mt;
|
||||||
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()), ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()), ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
||||||
|
addTArgument(alpha, beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,6 +77,30 @@ public class Mmul extends DynamicCustomOp {
|
||||||
this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse());
|
this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Mmul(INDArray x,
|
||||||
|
INDArray y,
|
||||||
|
INDArray z,
|
||||||
|
double alpha,
|
||||||
|
double beta,
|
||||||
|
MMulTranspose mt) {
|
||||||
|
addInputArgument(x, y);
|
||||||
|
|
||||||
|
if (z != null)
|
||||||
|
addOutputArgument(z);
|
||||||
|
|
||||||
|
if (mt != null) {
|
||||||
|
this.mt = mt;
|
||||||
|
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()),
|
||||||
|
ArrayUtil.fromBoolean(mt.isTransposeB()),
|
||||||
|
ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
||||||
|
}
|
||||||
|
|
||||||
|
this.alpha = alpha;
|
||||||
|
this.beta = beta;
|
||||||
|
|
||||||
|
addTArgument(alpha, beta);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param x
|
* @param x
|
||||||
|
@ -84,25 +111,30 @@ public class Mmul extends DynamicCustomOp {
|
||||||
INDArray y,
|
INDArray y,
|
||||||
INDArray z,
|
INDArray z,
|
||||||
MMulTranspose mt) {
|
MMulTranspose mt) {
|
||||||
super(null, new INDArray[]{x, y}, z == null ? null : new INDArray[]{z});
|
this(x, y, z, 1.0, 0.0, mt);
|
||||||
if (mt != null) {
|
|
||||||
this.mt = mt;
|
|
||||||
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()),
|
|
||||||
ArrayUtil.fromBoolean(mt.isTransposeB()),
|
|
||||||
ArrayUtil.fromBoolean(mt.isTransposeResult()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||||
|
this(x, y, 1.0, 0.0, transposeX, transposeY, transposeZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Mmul(INDArray x, INDArray y, double alpha, double beta, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||||
addInputArgument(x, y);
|
addInputArgument(x, y);
|
||||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||||
ArrayUtil.fromBoolean(transposeY),
|
ArrayUtil.fromBoolean(transposeY),
|
||||||
ArrayUtil.fromBoolean(transposeZ));
|
ArrayUtil.fromBoolean(transposeZ));
|
||||||
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
||||||
|
addTArgument(alpha, beta);
|
||||||
|
this.alpha = alpha;
|
||||||
|
this.beta = beta;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Mmul(INDArray x, INDArray y, double alpha, double beta) {
|
||||||
|
this(x,y,null, alpha, beta,null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Mmul(INDArray x, INDArray y) {
|
public Mmul(INDArray x, INDArray y) {
|
||||||
this(x,y,null,null);
|
this(x, y, 1.0, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
|
public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
|
||||||
|
@ -111,6 +143,8 @@ public class Mmul extends DynamicCustomOp {
|
||||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||||
ArrayUtil.fromBoolean(transposeY),
|
ArrayUtil.fromBoolean(transposeY),
|
||||||
ArrayUtil.fromBoolean(transposeZ));
|
ArrayUtil.fromBoolean(transposeZ));
|
||||||
|
|
||||||
|
addTArgument(alpha, beta);
|
||||||
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -791,7 +791,7 @@ public class Nd4j {
|
||||||
boolean transposeB) {
|
boolean transposeB) {
|
||||||
long cRows = (transposeA ? a.columns() : a.rows());
|
long cRows = (transposeA ? a.columns() : a.rows());
|
||||||
long cCols = (transposeB ? b.rows() : b.columns());
|
long cCols = (transposeB ? b.rows() : b.columns());
|
||||||
INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, 'f');
|
INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, a.ordering() == 'c' && b.ordering() == 'c' ? 'c' : 'f');
|
||||||
return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0);
|
return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -817,12 +817,9 @@ public class Nd4j {
|
||||||
boolean transposeB,
|
boolean transposeB,
|
||||||
double alpha,
|
double alpha,
|
||||||
double beta) {
|
double beta) {
|
||||||
//Note: some views have non-zero offset but 'default' strides (these are OK). And a 'c' order vector such as [10,1] is OK - same buffer as an 'f' order vector with same shape
|
Preconditions.checkArgument(c.elementWiseStride() == 1, "Nd4j.gemm() C array should NOT be a view");
|
||||||
Preconditions.checkState(c.length() == 1 || c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c) ||
|
|
||||||
c.isVectorOrScalar() && c.elementWiseStride() == 1,
|
Nd4j.exec(new Mmul(a, b, c, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build()));
|
||||||
"C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order " +
|
|
||||||
"and not a view. C (result) array: [%ndSInfo]", c);
|
|
||||||
getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta);
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,11 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L;
|
||||||
|
}
|
||||||
|
|
||||||
public static SameDiff getModel(){
|
public static SameDiff getModel(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -151,7 +156,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
||||||
|
|
||||||
CheckpointListener l = new CheckpointListener.Builder(dir)
|
CheckpointListener l = new CheckpointListener.Builder(dir)
|
||||||
.keepLast(2)
|
.keepLast(2)
|
||||||
.saveEvery(1, TimeUnit.SECONDS)
|
.saveEvery(4, TimeUnit.SECONDS)
|
||||||
.build();
|
.build();
|
||||||
sd.setListeners(l);
|
sd.setListeners(l);
|
||||||
|
|
||||||
|
@ -159,7 +164,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
||||||
|
|
||||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
for(int i=0; i<5; i++ ){ //10 iterations total
|
||||||
sd.fit(iter, 1);
|
sd.fit(iter, 1);
|
||||||
Thread.sleep(1000);
|
Thread.sleep(5000);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Expect models saved at iterations: 10, 20, 30, 40
|
//Expect models saved at iterations: 10, 20, 30, 40
|
||||||
|
|
|
@ -6192,24 +6192,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, output);
|
assertEquals(exp, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testVectorGemv() {
|
|
||||||
val vectorL = Nd4j.create(new float[]{1, 2, 3}, new long[]{3, 1});
|
|
||||||
val vectorN = Nd4j.create(new float[]{1, 2, 3}, new long[]{3});
|
|
||||||
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});
|
|
||||||
|
|
||||||
// log.info("vectorN: {}", vectorN);
|
|
||||||
// log.info("vectorL: {}", vectorL);
|
|
||||||
|
|
||||||
val outN = matrix.mmul(vectorN);
|
|
||||||
val outL = matrix.mmul(vectorL);
|
|
||||||
|
|
||||||
assertEquals(outL, outN.reshape(3,1));
|
|
||||||
|
|
||||||
assertEquals(1, outN.rank());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMatrixReshape() {
|
public void testMatrixReshape() {
|
||||||
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});
|
val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3});
|
||||||
|
|
|
@ -60,7 +60,7 @@ public class Level3Test extends BaseNd4jTest {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
||||||
|
|
||||||
INDArray array3 = array1.mmul(array2);
|
INDArray array3 = array1.mmul(array2, Nd4j.createUninitialized(new long[]{10, 10}, 'f'));
|
||||||
|
|
||||||
|
|
||||||
//System.out.println("Array3: " + Arrays.toString(array3.data().asFloat()));
|
//System.out.println("Array3: " + Arrays.toString(array3.data().asFloat()));
|
||||||
|
|
|
@ -83,7 +83,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
|
||||||
/**
|
/**
|
||||||
* Testing level2 blas
|
* Testing level2 blas
|
||||||
*/
|
*/
|
||||||
@Test(expected = ND4JIllegalStateException.class)
|
@Test(expected = RuntimeException.class)
|
||||||
public void testBlasValidation2() {
|
public void testBlasValidation2() {
|
||||||
INDArray a = Nd4j.create(100, 10);
|
INDArray a = Nd4j.create(100, 10);
|
||||||
INDArray x = Nd4j.create(100);
|
INDArray x = Nd4j.create(100);
|
||||||
|
|
|
@ -83,22 +83,7 @@ public class BlasTests extends BaseNd4jTest {
|
||||||
try {
|
try {
|
||||||
Nd4j.gemm(a, b, view, false, false, 1.0, 0.0);
|
Nd4j.gemm(a, b, view, false, false, 1.0, 0.0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
assertTrue(e.getMessage().contains("view"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGemmInvalid2() {
|
|
||||||
final INDArray a = Nd4j.rand(4, 3);
|
|
||||||
final INDArray b = Nd4j.rand(4, 5);
|
|
||||||
|
|
||||||
final INDArray target = Nd4j.zeros(3, 5, 'c');
|
|
||||||
|
|
||||||
try {
|
|
||||||
Nd4j.gemm(a, b, target, true, false, 1.0, 0.0);
|
|
||||||
fail("Expected exception");
|
|
||||||
} catch (IllegalStateException e) {
|
|
||||||
assertTrue(e.getMessage().contains("view"));
|
assertTrue(e.getMessage().contains("view"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -114,7 +99,7 @@ public class BlasTests extends BaseNd4jTest {
|
||||||
try {
|
try {
|
||||||
Nd4j.gemm(a, b, view, true, false, 1.0, 0.0);
|
Nd4j.gemm(a, b, view, true, false, 1.0, 0.0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
assertTrue(e.getMessage().contains("view"));
|
assertTrue(e.getMessage().contains("view"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue