Merge pull request #8703 from KonduitAI/master

Update master
Alex Black 2020-02-13 16:14:24 +11:00 committed by GitHub
commit 4b46aaedd3
67 changed files with 1151 additions and 267 deletions

View File

@@ -226,7 +226,7 @@ public class PythonExecutioner {
     private static void throwIfExecutionFailed() throws PythonException{
         PythonObject ex = getVariable(PYTHON_EXCEPTION_KEY);
-        if (ex != null && !ex.toString().isEmpty()){
+        if (ex != null && !ex.isNone() && !ex.toString().isEmpty()) {
             setVariable(PYTHON_EXCEPTION_KEY, new PythonObject(""));
             throw new PythonException(ex);
         }
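The new isNone() guard matters because getVariable can return a live PythonObject wrapping Python's None, and that wrapper stringifies to "None", a non-empty string, so the old condition raised a spurious PythonException. A minimal sketch of the distinction, using only API visible in this commit (the import paths are assumed from the datavec-python module):

import org.datavec.python.Python;       // assumed package
import org.datavec.python.PythonObject; // assumed package

public class NoneGuardSketch {
    public static void main(String[] args) {
        // dict.get("missing") yields Python None, wrapped in a non-null PythonObject:
        PythonObject none = Python.dict().attr("get").call("missing");
        System.out.println(none.toString().isEmpty()); // false: "None" is non-empty
        System.out.println(none.isNone());             // true after this commit, false before
    }
}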

View File

@@ -583,7 +583,9 @@ public class PythonObject {
         }
     }

     public boolean isNone() {
-        return nativePythonObject == null;
+        return (nativePythonObject == null) ||
+                (toString().equals("None") && Python.type(this).toString().equals("<class 'NoneType'>"));
     }
 }

View File

@@ -322,5 +322,16 @@ public class TestPythonExecutioner {
         Python.setMainContext();
     }

+    @Test
+    public void testIsNone(){
+        PythonObject d = Python.dict();
+        PythonObject none = d.attr("get").call("x");
+        Assert.assertTrue(none.isNone());
+        d.set(new PythonObject("x"), new PythonObject("y"));
+        PythonObject notNone = d.attr("get").call("x");
+        Assert.assertFalse(notNone.isNone());
+        Assert.assertEquals("y", notNone.toString());
+    }
 }

View File

@@ -414,7 +414,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
     INDArray l = TestUtils.randomOneHot(mb, 3);
     INDArray lm = TestUtils.randomBernoulli(mb, 1);
-    assertTrue(lm.sumNumber().intValue() > 0);
+    int attempts = 0;
+    while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
+        lm = TestUtils.randomBernoulli(mb, 1);
+    }
+    assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);

     boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
             .labels(l).labelMask(lm));
@@ -467,7 +471,11 @@
     INDArray l = TestUtils.randomOneHot(mb, 3);
     INDArray lm = TestUtils.randomBernoulli(mb, 1);
-    assertTrue(lm.sumNumber().intValue() > 0);
+    int attempts = 0;
+    while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
+        lm = TestUtils.randomBernoulli(mb, 1);
+    }
+    assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);

     boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f})
             .labels(new INDArray[]{l}).labelMask(new INDArray[]{lm}));
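Both hunks replace a flaky hard assertion with bounded resampling: a Bernoulli(0.5) mask over a small minibatch is occasionally all zeros, which previously failed the test outright. The pattern in isolation, as a sketch that assumes only java.util.Random (TestUtils is the project's helper and is not reproduced here):

import java.util.Random;

public class NonZeroMaskSketch {
    /** Resample a Bernoulli(0.5) mask until at least one element is 1, up to 1000 attempts. */
    static int[] nonZeroMask(int length, Random rng) {
        for (int attempt = 0; attempt < 1000; attempt++) {
            int[] mask = new int[length];
            int ones = 0;
            for (int i = 0; i < length; i++) {
                mask[i] = rng.nextBoolean() ? 1 : 0;
                ones += mask[i];
            }
            if (ones > 0) return mask; // usable: the gradient check has something to mask in
        }
        throw new IllegalStateException("Could not generate non-zero mask after 1000 attempts");
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(nonZeroMask(4, new Random(12345))));
    }
}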

View File

@@ -67,7 +67,9 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest {
         }
     }

-    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
+    SparkConf sparkConf = new SparkConf().setMaster("local[8]")
+            .set("spark.driver.host", "localhost")
+            .setAppName("SeqVecTests");
     sc = new JavaSparkContext(sparkConf);
 }
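This hunk and every Spark test change below apply the same fix: pinning spark.driver.host to localhost so a local[*] test context can start even when the machine's hostname does not resolve cleanly, a common CI failure mode. The shared pattern as a standalone sketch (standard Spark Java API; the app name and thread count are illustrative):

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

public class LocalSparkContextSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf()
                .setMaster("local[4]")                 // in-process master for tests
                .set("spark.driver.host", "localhost") // skip hostname resolution
                .setAppName("sparktest");              // illustrative name
        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            System.out.println("Spark context up, version " + sc.version());
        } finally {
            sc.stop(); // always release the local context
        }
    }
}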

View File

@@ -61,7 +61,9 @@ public class SparkWord2VecTest extends BaseDL4JTest {
         sentences.add("one another sentence");
     }

-    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
+    SparkConf sparkConf = new SparkConf().setMaster("local[8]")
+            .set("spark.driver.host", "localhost")
+            .setAppName("SeqVecTests");
     sc = new JavaSparkContext(sparkConf);
 }

View File

@@ -56,7 +56,9 @@ public class Word2VecTest {
     @Test
     public void testConcepts() throws Exception {
         // These are all default values for word2vec
-        SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[8]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("sparktest");

         // Set SparkContext
         JavaSparkContext sc = new JavaSparkContext(sparkConf);
@@ -156,6 +158,7 @@ public class Word2VecTest {
     @Test
     public void testSparkW2VonBiggerCorpus() throws Exception {
         SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
                 .set("spark.driver.maxResultSize", "4g").set("spark.driver.memory", "8g")
                 .set("spark.executor.memory", "8g");

View File

@@ -63,7 +63,7 @@ public class TextPipelineTest extends BaseSparkTest {
     @Before
     public void before() throws Exception {
-        conf = new SparkConf().setMaster("local[4]").setAppName("sparktest");
+        conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost");

         // All the avaliable options. These are default values
         word2vec = new Word2Vec.Builder().minWordFrequency(1).setNGrams(1)

View File

@@ -85,7 +85,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
     if (sc != null)
         return sc;

     // set to test mode
-    SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
+    SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
+            .set("spark.driver.host", "localhost").setAppName("sparktest");

     sc = new JavaSparkContext(sparkConf);

View File

@@ -59,7 +59,9 @@ public class BaseSparkKryoTest extends BaseSparkTest {
-    SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
+    SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
+            .setAppName("sparktest")
+            .set("spark.driver.host", "localhost");

     sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
     sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");

View File

@@ -89,7 +89,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
     if (sc != null)
         return sc;

     // set to test mode
-    SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
+    SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
+            .set("spark.driver.host", "localhost").setAppName("sparktest");

     sc = new JavaSparkContext(sparkConf);

View File

@@ -72,8 +72,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageMLNIncorrectConfig() {
         //Should print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
-                "org.apache.spark.serializer.KryoSerializer");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
+                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");

         doTestMLN(sparkConf);
     }
@@ -83,6 +84,7 @@ public class TestKryoWarning {
     public void testKryoMessageMLNCorrectConfigKryo() {
         //Should NOT print warning message
         SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
                 .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                 .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
@@ -93,7 +95,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageMLNCorrectConfigNoKryo() {
         //Should NOT print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("sparktest");

         doTestMLN(sparkConf);
     }
@@ -104,8 +108,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageCGIncorrectConfig() {
         //Should print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
-                "org.apache.spark.serializer.KryoSerializer");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
+                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");

         doTestCG(sparkConf);
     }
@@ -115,6 +120,7 @@ public class TestKryoWarning {
     public void testKryoMessageCGCorrectConfigKryo() {
         //Should NOT print warning message
         SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
                 .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                 .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
@@ -125,7 +131,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageCGCorrectConfigNoKryo() {
         //Should NOT print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("sparktest");

         doTestCG(sparkConf);
     }

View File

@@ -138,6 +138,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
     SparkConf sparkConf = new SparkConf();
     sparkConf.setMaster("local[" + nWorkers + "]");
     sparkConf.setAppName("Test");
+    sparkConf.set("spark.driver.host", "localhost");
     JavaSparkContext sc = new JavaSparkContext(sparkConf);
     return sc;

View File

@@ -58,7 +58,7 @@ public class ExportSupportTest {
     }

     private void assertSupported(SparkConf conf) throws IOException {
-        JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test"));
+        JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test").set("spark.driver.host", "localhost"));
         try {
             assertTrue(ExportSupport.exportSupported(sc));
         } finally {

View File

@@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.spark.BaseSparkTest;
 import org.deeplearning4j.spark.api.Repartition;
 import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
 import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
@@ -50,17 +51,13 @@ import static org.junit.Assert.*;

 /**
  * Created by Alex on 17/06/2016.
  */
-public class TestTrainingStatsCollection {
+public class TestTrainingStatsCollection extends BaseSparkTest {

     @Test
     public void testStatsCollection() throws Exception {
-        int nWorkers = 4;
-        SparkConf sparkConf = new SparkConf();
-        sparkConf.setMaster("local[" + nWorkers + "]");
-        sparkConf.setAppName("Test");
-        JavaSparkContext sc = new JavaSparkContext(sparkConf);
+        int nWorkers = numExecutors();
+        JavaSparkContext sc = getContext();

         try {

View File

@ -1,4 +1,4 @@
{ {
"configurations": [ "configurations": [
{ {
"name": "x64-Debug", "name": "x64-Debug",

View File

@@ -32,9 +32,10 @@ template <typename X, typename Y, typename Z>
 template<typename OpType>
 void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {

     const X* x = reinterpret_cast<X*>(xArr.getBuffer());
     const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
     Z* z = reinterpret_cast<Z*>(zArr.getBuffer());

     const auto xShapeInfo = xArr.getShapeInfo();
     const auto yShapeInfo = yArr.getShapeInfo();
@@ -44,8 +45,26 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
     const int yRank = yArr.rankOf();
     const int zRank = zArr.rankOf();

+    bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
+                         1 == yArr.ews() && 'c' == yArr.ordering() &&
+                         1 == zArr.ews() && 'c' == zArr.ordering());
+
+    if (bSpecialCase) {
+        auto yLen = (uint32_t)yArr.lengthOf();
+        auto func = PRAGMA_THREADS_FOR {
+            for (uint32_t i = start; i < stop; i++) {
+                auto rZ = z + (i * yLen);
+                auto v = x[i];
+                for (uint32_t j = 0; j < yLen; j++) {
+                    rZ[j] = OpType::op(v, y[j]);
+                }
+            }
+        };
+        samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
+        return;
+    }
+
     const Nd4jLong zLen = zArr.lengthOf();

     auto func = PRAGMA_THREADS_FOR {
         std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
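The fast path covers the case where x, y and z are all contiguous ('c' order, element-wise stride 1) and y is rank 1: each scalar of x is combined with the whole of y, writing one contiguous yLen-wide stripe of z per x element, with no per-element coordinate arithmetic. Seen from the Java API the shape pattern is a trailing-dimension broadcast; a sketch (standard Nd4j API; whether plain add() broadcasts these shapes depends on the Nd4j version, so treat it as illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TrailingBroadcastSketch {
    public static void main(String[] args) {
        // x: [5, 2, 1], y: [3] -> z: [5, 2, 3], mirroring the C++ test added below.
        INDArray x = Nd4j.linspace(1, 10, 10).reshape(5, 2, 1);
        INDArray y = Nd4j.ones(3);
        INDArray z = x.add(y); // each x scalar paired with all of y
        System.out.println(java.util.Arrays.toString(z.shape())); // [5, 2, 3]
    }
}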

View File

@@ -80,8 +80,8 @@ namespace nd4j {
             "resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
         REQUIRE_TRUE(block.numI() <= 1, 0,
             "resize_area: Resize params already given by the second param. Int params are expensive.");
-        width = newImageSize->e<int>(0);
-        height = newImageSize->e<int>(1);
+        width = newImageSize->e<int>(1);
+        height = newImageSize->e<int>(0);
     }
     else {
         REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor.");
@@ -95,13 +95,13 @@
     outputShape[0] = inRank;
     if (inRank == 4) {
         outputShape[1] = in[1];
-        outputShape[2] = width;
-        outputShape[3] = height;
+        outputShape[2] = height;
+        outputShape[3] = width;
         outputShape[4] = in[4];
     }
     else {
-        outputShape[1] = width;
-        outputShape[2] = height;
+        outputShape[1] = height;
+        outputShape[2] = width;
         outputShape[3] = in[3];
     }
     ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));

View File

@@ -1116,7 +1116,7 @@ namespace helpers {
     err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
     ScaleCache<T>* cachePool;
     err = cudaMalloc(&cachePool, sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight);
-    resizeAreaKernel<T><<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
+    resizeAreaKernel<T><<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
             output->specialShapeInfo(), cachePool);
     err = cudaStreamSynchronize(*stream);
     err = cudaFree(cachePool);

View File

@@ -1520,6 +1520,65 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
     delete results;
 }

+TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) {
+    NDArray input = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
+        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25
+    });
+    auto size = NDArrayFactory::create<int>({8, 7});
+    NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
+        1.f, 1.6f, 2.1999993f, 2.9999995f, 3.8f, 4.399997f, 5.f, 2.9999995f, 3.5999997f, 4.199999f,
+        4.9999995f, 5.8f, 6.3999963f, 7.f, 5.999999f, 6.6f, 7.1999984f, 7.9999995f, 8.8f,
+        9.399994f, 10.f, 10.f, 10.6f, 11.199998f, 12.f, 12.8f, 13.399992f, 14.f, 12.f, 12.599999f,
+        13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f, 15.999999f, 16.599998f, 17.199995f,
+        18.f, 18.800003f, 19.399986f, 20.000002f, 19.f, 19.599998f, 20.199997f,
+        20.999998f, 21.800003f, 22.399984f, 23.000002f, 20.999998f,
+        21.599998f, 22.199995f, 22.999998f, 23.800001f, 24.399984f,
+        25.f
+    }); //input.linspace(1);
+    // auto size = NDArrayFactory::create<int>({6, 6});
+    nd4j::ops::resize_area op;
+    auto results = op.evaluate({&input, &size}, {}, {false});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    NDArray* result = results->at(0);
+    // result->printBuffer("Area Resized to 8x7");
+    // expected.printBuffer("Area Expect for 8x7");
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+    delete results;
+}
+
+TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
+    NDArray input = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
+        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25
+    });
+    //auto size = NDArrayFactory::create<int>({8, 7});
+    NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
+        1.f, 1.6f, 2.1999993f, 2.9999995f, 3.8f, 4.399997f, 5.f, 2.9999995f, 3.5999997f, 4.199999f,
+        4.9999995f, 5.8f, 6.3999963f, 7.f, 5.999999f, 6.6f, 7.1999984f, 7.9999995f, 8.8f,
+        9.399994f, 10.f, 10.f, 10.6f, 11.199998f, 12.f, 12.8f, 13.399992f, 14.f, 12.f, 12.599999f,
+        13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f, 15.999999f, 16.599998f, 17.199995f,
+        18.f, 18.800003f, 19.399986f, 20.000002f, 19.f, 19.599998f, 20.199997f,
+        20.999998f, 21.800003f, 22.399984f, 23.000002f, 20.999998f, 21.599998f, 22.199995f,
+        22.999998f, 23.800001f, 24.399984f, 25.f
+    });
+
+    nd4j::ops::resize_area op;
+    auto results = op.evaluate({&input}, {}, {8, 7}, {false});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    NDArray* result = results->at(0);
+    // result->printBuffer("Area Resized to 8x7");
+    // expected.printBuffer("Area Expect for 8x7");
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+    delete results;
+}
+
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {

View File

@ -532,3 +532,24 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
delete result; delete result;
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
y.assign(1.0);
x.linspace(1.0);
nd4j::ops::add op;
auto result = op.evaluate({ &x, &y });
ASSERT_EQ(Status::OK(), result->status());
auto res = *result->at(0);
ASSERT_EQ(e, res);
delete result;
}

View File

@@ -294,6 +294,11 @@ public class DifferentialFunctionFactory {
         return new ZerosLike(name, sameDiff(), input).outputVariable();
     }

+    public SDVariable zerosLike(String name, SDVariable input, DataType dataType) {
+        validateDifferentialFunctionsameDiff(input);
+        return new ZerosLike(name, sameDiff(), input, dataType).outputVariable();
+    }
+
     public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) {
         return create(name, shape, 'c', initialize, dataType);
     }
@@ -1751,12 +1756,12 @@
         return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables();
     }

-    public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
-        return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable();
+    public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int classDim) {
+        return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable();
     }

-    public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
-        return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
+    public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) {
+        return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables();
     }

     public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){
@@ -2638,7 +2643,7 @@
         return new Polygamma(sameDiff, n,x).outputVariable();
     }

-    public SDVariable roll(SDVariable input, SDVariable shift) {
+    public SDVariable roll(SDVariable input, int shift) {
         return new Roll(sameDiff, input, shift).outputVariable();
     }

View File

@@ -787,9 +787,10 @@ public abstract class SDBaseOps {
      * @param number Number of values to generate
      * @return SDVariable with linearly spaced elements
      */
-    public SDVariable linspace(DataType dataType, double start, double stop, long number) {
-        return linspace(dataType, start, stop, number);
-    }
+    // TODO: fix or remove, currently it is internal recursion
+    /*public SDVariable linspace(DataType dataType, double start, double stop, long number) {
+        return linspace(dataType, start, stop, number);
+    }*/

     /**
      * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
@@ -3093,6 +3094,9 @@
         return zerosLike(null, input);
     }

+    public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) {
+        return zerosLike(null, input, dataType);
+    }
+
     /**
      * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
      * if the input shape changes in later execution, the returned variable's shape will also be updated
@@ -3106,6 +3110,10 @@
         return updateVariableNameAndReference(ret, name);
     }

+    public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) {
+        SDVariable ret = f().zerosLike(name, input, dataType);
+        return updateVariableNameAndReference(ret, name);
+    }
+
     /**
      * See {@link #any(String, SDVariable, int...)}
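The new zerosLike overloads let the zeros variable take a different data type than its template input while still tracking the input's (possibly dynamic) shape. A usage sketch (standard SameDiff API; names and shapes illustrative):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class ZerosLikeSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", DataType.FLOAT, 2, 3);
        // Same shape as 'in', but INT zeros instead of FLOAT:
        SDVariable z = sd.zerosLike(in, DataType.INT);
        System.out.println(java.util.Arrays.toString(z.eval().shape())); // [2, 3]
    }
}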

View File

@@ -2545,7 +2545,7 @@ public class SDMath extends SDOps {
      * @param shift number of places to shift elements
      * @return array
      */
-    public SDVariable roll(String name, SDVariable input, SDVariable shift) {
+    public SDVariable roll(String name, SDVariable input, int shift) {
         SDVariable res = f().roll(input,shift);
         return updateVariableNameAndReference(res, name);
     }
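roll now takes the shift as a plain int instead of an SDVariable, matching the Roll constructor added further below that passes the shift as an integer argument. Usage sketch (standard SameDiff API; data illustrative, shift semantics expected to follow TensorFlow's roll):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class RollSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.createFromArray(1f, 2f, 3f, 4f));
        SDVariable rolled = sd.math().roll("rolled", in, 1); // expected [4, 1, 2, 3]
        System.out.println(rolled.eval());
    }
}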

View File

@@ -815,8 +815,9 @@
     }

     int[] dims;
-    if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL
-            || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) {
+    Type t = node.opType();
+    if (t == Op.Type.REDUCE_FLOAT || t == Op.Type.REDUCE_SAME || t == Op.Type.REDUCE_BOOL
+            || t == Op.Type.REDUCE_LONG || t == Op.Type.INDEXREDUCE || t == Op.Type.REDUCE3 || t == Type.VARIANCE || t == Type.SUMMARYSTATS) {
         dims = node.getDimensions();
         if (dims == null)
             dims = new int[0];

View File

@@ -16,6 +16,7 @@
 package org.nd4j.autodiff.validation;

+import org.nd4j.linalg.api.ops.custom.*;
 import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
 import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
 import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
@@ -38,10 +39,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOpDescriptor;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
-import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
-import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
-import org.nd4j.linalg.api.ops.custom.SpTreeCell;
 import org.nd4j.linalg.api.ops.impl.broadcast.bool.*;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.loss.bp.*;
@@ -1011,7 +1008,10 @@ public class OpValidation {
             SpTreeCell.class,
             CbowRound.class,
             SkipGramRound.class,
-            HashCode.class
+            HashCode.class,
+            HashCode.class,
+            BitCast.class,
+            ToggleBits.class
     );

     return new HashSet<>(list);

View File

@@ -200,7 +200,6 @@ public class ImportClassMapping {
     org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul.class,
     org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp.class,
     org.nd4j.linalg.api.ops.impl.reduce.floating.AMean.class,
-    org.nd4j.linalg.api.ops.impl.reduce.floating.Bias.class,
     org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy.class,
     org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy.class,
     org.nd4j.linalg.api.ops.impl.reduce.floating.Mean.class,

View File

@@ -16,21 +16,28 @@
 package org.nd4j.linalg.api.ops.custom;

+import lombok.Data;
+import lombok.NoArgsConstructor;
 import lombok.val;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;

+import java.util.Arrays;
+import java.util.List;
+
 /**
  * This op takes arbitrary number of arrays as input, and returns single "flattened" vector
  *
  * @author raver119@gmail.com
  */
+@Data
+@NoArgsConstructor
 public class Flatten extends DynamicCustomOp {
-    private char order;
-
-    public Flatten() {
-        //
-    }
+    private int order;

     public Flatten(char order, INDArray... inputs) {
         this.order = order;
@@ -47,10 +54,21 @@ public class Flatten extends DynamicCustomOp {
         outputArguments.add(output);
     }

+    public Flatten(SameDiff sameDiff, char order, SDVariable... inputs) {
+        super(sameDiff, inputs);
+        this.order = order;
+        addIArgument(order);
+    }
+
     @Override
     public String opName() {
         return "flatten";
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Arrays.asList(inputDataTypes.get(0));
+    }
 }

View File

@@ -51,6 +51,14 @@ public class FusedBatchNorm extends DynamicCustomOp {
     public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
                           @NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
         super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining});
+        this.outputDataType = x.dataType();
+    }
+
+    public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
+                          int dataFormat, int isTraining) {
+        super("", sameDiff, new SDVariable[]{x, scale, offset});
+        addIArgument(dataFormat, isTraining);
+        this.outputDataType = x.dataType();
     }

     @Override
@@ -78,6 +86,8 @@ public class FusedBatchNorm extends DynamicCustomOp {
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
-        return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float
+        return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
+                outputDataType == null ? DataType.FLOAT : outputDataType,
+                outputDataType == null ? DataType.FLOAT : outputDataType);
     }
 }

View File

@@ -64,6 +64,6 @@ public class Lu extends DynamicCustomOp {
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
-        return Arrays.asList(inputDataTypes.get(0), indexDataType);
+        return Arrays.asList(inputDataTypes.get(0), indexDataType == null ? DataType.INT32 : indexDataType);
     }
 }

View File

@@ -46,6 +46,11 @@ public class MatrixBandPart extends DynamicCustomOp {
         super("", sameDiff, new SDVariable[]{input, minLower, maxUpper});
     }

+    public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int minLower, int maxUpper) {
+        super("", sameDiff, new SDVariable[]{input});
+        addIArgument(minLower, maxUpper);
+    }
+
     @Override
     public String opName() {
         return "matrix_band_part";

View File

@@ -45,6 +45,15 @@ public class Roll extends DynamicCustomOp {
         super("", sameDiff, new SDVariable[]{input,shift});
     }

+    public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) {
+        super("", sameDiff, new SDVariable[]{input,axes,shift});
+    }
+
+    public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {
+        super("", sameDiff, new SDVariable[]{input});
+        addIArgument(shift);
+    }
+
     @Override
     public String opName() {
         return "roll";

View File

@@ -7,9 +7,13 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;

 import java.util.Collections;
 import java.util.List;
+import java.util.Map;

 @NoArgsConstructor
 public class TriangularSolve extends DynamicCustomOp {
@@ -24,11 +28,27 @@ public class TriangularSolve extends DynamicCustomOp {
         super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint});
     }

+    public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs,
+                           boolean lower, boolean adjoint) {
+        super(sameDiff, new SDVariable[] {matrix, rhs});
+        addBArgument(lower, adjoint);
+    }
+
     @Override
     public String opName() {
         return "triangular_solve";
     }

+    @Override
+    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
+        if(attributesForNode.containsKey("adjoint")){
+            addBArgument(attributesForNode.get("adjoint").getB());
+        }
+        if(attributesForNode.containsKey("lower")){
+            addBArgument(attributesForNode.get("lower").getB());
+        }
+    }
+
     @Override
     public String tensorflowName() {
         return "MatrixTriangularSolve";

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.broadcast;

+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -27,6 +28,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import java.util.Arrays;
 import java.util.List;

+@NoArgsConstructor
 public class BiasAddGrad extends DynamicCustomOp {
     protected boolean nchw = true;
@@ -40,7 +42,16 @@ public class BiasAddGrad extends DynamicCustomOp {
         super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
     }

-    public BiasAddGrad() {}
+    public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient,
+                       boolean nchw) {
+        addInputArgument(input, bias, gradient);
+        this.nchw = nchw;
+        addBArgument(nchw);
+    }
+
+    public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient) {
+        this(input, bias, gradient, false);
+    }

     @Override
     public int opNum() {
View File

@@ -16,6 +16,8 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;

+import lombok.Data;
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
@@ -35,6 +37,8 @@ import java.util.List;
  *
  * @author raver119@gmail.com
  */
+@Data
+@NoArgsConstructor
 public class FirstIndex extends BaseIndexAccumulation {
     protected Condition condition;
     protected double compare;
@@ -50,9 +54,6 @@ public class FirstIndex extends BaseIndexAccumulation {
         this.extraArgs = new Object[] {compare, eps, (double) mode};
     }

-    public FirstIndex() {}
-
     public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
         this(x, condition, false, dimension);
     }
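Here and in the index-accumulation ops below, the Lombok annotations replace hand-written boilerplate: @NoArgsConstructor supplies the no-arg constructor that was deleted, and @Data generates getters, setters, equals/hashCode and toString over fields such as condition and compare. Roughly what the annotations expand to inside FirstIndex (a sketch of Lombok's generated members, not code from the commit):

// Generated by @NoArgsConstructor:
public FirstIndex() {}

// Generated by @Data (one representative getter/setter pair shown):
public Condition getCondition() { return condition; }
public void setCondition(Condition condition) { this.condition = condition; }
// ...plus equals(), hashCode() and toString() over all declared fields.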

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;

+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -30,6 +31,7 @@ import java.util.List;
  *
  * @author Adam Gibson
  */
+@Data
 public class IAMax extends BaseIndexAccumulation {
     public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;

+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -30,6 +31,7 @@ import java.util.List;
  *
  * @author Adam Gibson
  */
+@Data
 public class IAMin extends BaseIndexAccumulation {
     public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;

+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -31,6 +32,7 @@ import java.util.List;
  *
  * @author Alex Black
  */
+@Data
 public class IMax extends BaseIndexAccumulation {
     public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;

+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -30,6 +31,7 @@ import java.util.List;
  *
  * @author Alex Black
  */
+@Data
 public class IMin extends BaseIndexAccumulation {
     public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;

+import lombok.Data;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -36,6 +37,7 @@
  *
  * @author raver119@gmail.com
  */
+@Data
 public class LastIndex extends BaseIndexAccumulation {
     protected Condition condition;
     protected double compare;

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum.custom;

+import lombok.Data;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
@@ -29,6 +30,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;

+@Data
 public class ArgMax extends DynamicCustomOp {
     protected DataType outputType;

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum.custom;

+import lombok.Data;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
@@ -34,6 +35,7 @@ import java.util.Map;
  *
  * @author Alex Black
  */
+@Data
 public class ArgMin extends DynamicCustomOp {
     protected DataType outputType = DataType.LONG;

View File

@@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;

 import java.util.Arrays;
 import java.util.Collections;
@@ -38,8 +39,14 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
     protected int classesDim;

-    public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
-        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+    // public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
+    //     super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+    //     this.classesDim = classesDim;
+    //     addIArgument(classesDim);
+    // }
+
+    public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
+        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
         this.classesDim = classesDim;
         addIArgument(classesDim);
     }
@@ -66,7 +73,8 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> grad){
         //No external gradient
         //Args: logits, weigths, label
-        SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim);
+        SDVariable[] args = args();
+        SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim);
         return Arrays.asList(grads);
     }
 }
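The weights input is dropped because the underlying kernel consumes only logits and labels (TensorFlow's softmax_cross_entropy_with_logits likewise has no weights input); doDiff's argument indices shift accordingly from arg(2)/arg(0)/arg(1) to arg(0)/arg(1). A construction sketch against the new signature (import path assumed from this repository's layout; names and shapes illustrative):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; // assumed package

public class SoftmaxCeSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable logits = sd.var("logits", DataType.FLOAT, 4, 3);
        SDVariable labels = sd.var("labels", DataType.FLOAT, 4, 3);
        // classesDim = 1: classes sit on dimension 1 of [minibatch, classes]
        SDVariable loss = new SoftmaxCrossEntropyWithLogitsLoss(sd, logits, labels, 1)
                .outputVariable();
        System.out.println(loss);
    }
}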

View File

@@ -20,14 +20,18 @@ import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.impl.loss.BaseLoss;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;

+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
@@ -56,4 +60,12 @@ public class SoftmaxCrossEntropyLossBp extends BaseLossBp {

     public String opName() {
         return "softmax_cross_entropy_loss_grad";
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
+                "Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
+        return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(1), inputDataTypes.get(2)); //Same as predictions
+    }
 }

View File

@@ -19,8 +19,10 @@ package org.nd4j.linalg.api.ops.impl.loss.bp;
 import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;

+import java.util.Arrays;
 import java.util.List;
@@ -34,8 +36,8 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
     protected int classesDim;

-    public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
-        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+    public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
+        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
         this.classesDim = classesDim;
         addIArgument(classesDim);
     }
@@ -49,4 +51,9 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> grad){
         throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        return Arrays.asList(arg(0).dataType(), arg(1).dataType());
+    }
 }

View File

@@ -16,9 +16,12 @@
 package org.nd4j.linalg.api.ops.impl.reduce;

+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;

 import java.util.*;
@@ -30,12 +33,9 @@ import java.util.*;
  *
  * @author Alex Black
  */
+@NoArgsConstructor
 public class SufficientStatistics extends DynamicCustomOp {

-    public SufficientStatistics() {
-    }
-
     public SufficientStatistics(SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable axis, SDVariable shift) {
         super(null, sameDiff, argsNoNull(x, axis, shift), false);
     }
@@ -48,14 +48,30 @@ public class SufficientStatistics extends DynamicCustomOp {
         }
     }

+    public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes, INDArray shift) {
+        if (shift != null)
+            addInputArgument(x, axes, shift);
+        else
+            addInputArgument(x, axes);
+    }
+
+    public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes) {
+        this(x,axes,null);
+    }
+
     @Override
     public String opName() {
         return "sufficient_statistics";
     }

     @Override
     public List<SDVariable> doDiff(List<SDVariable> grad) {
         throw new UnsupportedOperationException("Backprop not yet implemented for op: " + getClass().getSimpleName());
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        // FIXME
+        return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0), inputDataTypes.get(0));
+    }
 }

View File

@ -111,7 +111,7 @@ public class TensorMmul extends DynamicCustomOp {
int[][] deletedAxes = new int[][]{ int[][] deletedAxes = new int[][]{
removeIndex(aAxes, sumAxes[0]), removeIndex(aAxes, sumAxes[0]),
removeIndex(bAxes, sumAxes[1])}; removeIndex(bAxes, sumAxes[1])};
int[] gAxes = range(0, i_v1.get(0).getShape().length); int[] gAxes = range(0, i_v1.get(0).eval().shape().length);
int[][] firstAxes = new int[][]{ int[][] firstAxes = new int[][]{
Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length), Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length),
deletedAxes[1] deletedAxes[1]
@ -144,18 +144,20 @@ public class TensorMmul extends DynamicCustomOp {
int[][] axes) { int[][] axes) {
int validationLength = Math.min(axes[0].length, axes[1].length); int validationLength = Math.min(axes[0].length, axes[1].length);
INDArray aArray = a.eval();
INDArray bArray = b.eval();
for (int i = 0; i < validationLength; i++) { for (int i = 0; i < validationLength; i++) {
if (a.getShape()[axes[0][i]] != b.getShape()[axes[1][i]]) if (aArray.shape()[axes[0][i]] != bArray.shape()[axes[1][i]])
throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size."); throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
if (axes[0][i] < 0) if (axes[0][i] < 0)
axes[0][i] += a.getShape().length; axes[0][i] += aArray.shape().length;
if (axes[1][i] < 0) if (axes[1][i] < 0)
axes[1][i] += b.getShape().length; axes[1][i] += bArray.shape().length;
} }
List<Integer> listA = new ArrayList<>(); List<Integer> listA = new ArrayList<>();
for (int i = 0; i < a.getShape().length; i++) { for (int i = 0; i < aArray.shape().length; i++) {
if (!Ints.contains(axes[0], i)) if (!Ints.contains(axes[0], i))
listA.add(i); listA.add(i);
} }
@ -164,7 +166,7 @@ public class TensorMmul extends DynamicCustomOp {
List<Integer> listB = new ArrayList<>(); List<Integer> listB = new ArrayList<>();
for (int i = 0; i < b.getShape().length; i++) { for (int i = 0; i < bArray.shape().length; i++) {
if (!Ints.contains(axes[1], i)) if (!Ints.contains(axes[1], i))
listB.add(i); listB.add(i);
} }
@ -172,9 +174,9 @@ public class TensorMmul extends DynamicCustomOp {
int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB)); int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB));
int n2 = 1; int n2 = 1;
int aLength = Math.min(a.getShape().length, axes[0].length); int aLength = Math.min(aArray.shape().length, axes[0].length);
for (int i = 0; i < aLength; i++) { for (int i = 0; i < aLength; i++) {
n2 *= a.getShape()[axes[0][i]]; n2 *= aArray.shape()[axes[0][i]];
} }
//if listA and listB are empty these do not initialize. //if listA and listB are empty these do not initialize.
@ -186,13 +188,13 @@ public class TensorMmul extends DynamicCustomOp {
} else { } else {
oldShapeA = Longs.toArray(listA); oldShapeA = Longs.toArray(listA);
for (int i = 0; i < oldShapeA.length; i++) for (int i = 0; i < oldShapeA.length; i++)
oldShapeA[i] = a.getShape()[(int) oldShapeA[i]]; oldShapeA[i] = aArray.shape()[(int) oldShapeA[i]];
} }
int n3 = 1; int n3 = 1;
int bNax = Math.min(b.getShape().length, axes[1].length); int bNax = Math.min(bArray.shape().length, axes[1].length);
for (int i = 0; i < bNax; i++) { for (int i = 0; i < bNax; i++) {
n3 *= b.getShape()[axes[1][i]]; n3 *= bArray.shape()[axes[1][i]];
} }
@ -203,7 +205,7 @@ public class TensorMmul extends DynamicCustomOp {
} else { } else {
oldShapeB = Longs.toArray(listB); oldShapeB = Longs.toArray(listB);
for (int i = 0; i < oldShapeB.length; i++) for (int i = 0; i < oldShapeB.length; i++)
oldShapeB[i] = b.getShape()[(int) oldShapeB[i]]; oldShapeB[i] = bArray.shape()[(int) oldShapeB[i]];
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.reduce.bp; package org.nd4j.linalg.api.ops.impl.reduce.bp;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -30,7 +31,7 @@ import java.util.List;
/** /**
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public abstract class BaseReductionBp extends DynamicCustomOp { public abstract class BaseReductionBp extends DynamicCustomOp {
protected boolean keepDims; protected boolean keepDims;
@ -96,7 +97,12 @@ public abstract class BaseReductionBp extends DynamicCustomOp {
addArgs(); addArgs();
} }
public BaseReductionBp(){} public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int... dimensions){
super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, new INDArray[]{output1, output2});
this.keepDims = keepDims;
this.dimensions = dimensions;
addArgs();
}
protected void addArgs(){ protected void addArgs(){
addTArgument(keepDims ? 1 : 0); addTArgument(keepDims ? 1 : 0);

View File

@ -16,17 +16,23 @@
package org.nd4j.linalg.api.ops.impl.reduce.bp; package org.nd4j.linalg.api.ops.impl.reduce.bp;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
/** /**
* Backprop op for Dot pairwise reduction operation * Backprop op for Dot pairwise reduction operation
* *
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public class DotBp extends BaseReductionBp { public class DotBp extends BaseReductionBp {
public DotBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) { public DotBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) {
@ -37,10 +43,22 @@ public class DotBp extends BaseReductionBp {
super(origInput1, origInput2, gradAtOutput, output, keepDims, dimensions); super(origInput1, origInput2, gradAtOutput, output, keepDims, dimensions);
} }
public DotBp(){} public DotBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput,
INDArray outputX, INDArray outputY, boolean keepDims, int... dimensions) {
super(origInput1, origInput2, gradAtOutput, outputX, outputY, keepDims, dimensions);
}
@Override @Override
public String opName() { public String opName() {
return "reduce_dot_bp"; return "reduce_dot_bp";
} }
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input data types for %s, got %s", getClass(), dataTypes);
Preconditions.checkState(dataTypes.get(0).isFPType(), "First input must be a floating point type, got %s", dataTypes.get(0));
Preconditions.checkState(dataTypes.get(1).isFPType(), "Second input must be a floating point type, got %s", dataTypes.get(1));
Preconditions.checkState(dataTypes.get(2).isFPType(), "Third input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(2));
return Arrays.asList(dataTypes.get(0), dataTypes.get(0));
}
} }
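Note: the new array-based constructor lets the backprop op be run directly, outside a SameDiff graph; a rough usage sketch (shapes and names illustrative, assuming the standard Nd4j.exec(CustomOp) execution path):

INDArray x = Nd4j.rand(DataType.FLOAT, 3, 4);
INDArray y = Nd4j.rand(DataType.FLOAT, 3, 4);
INDArray dLdOut = Nd4j.rand(DataType.FLOAT, 4);            // gradient at the dot output, reduced over dim 0
INDArray dLdX = Nd4j.create(DataType.FLOAT, 3, 4);
INDArray dLdY = Nd4j.create(DataType.FLOAT, 3, 4);
Nd4j.exec(new DotBp(x, y, dLdOut, dLdX, dLdY, false, 0));  // fills dLdX and dLdY in place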

View File

@ -1,86 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.reduce.floating;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* Calculate a bias
*
* @author Adam Gibson
*/
public class Bias extends BaseReduceFloatOp {
private double mean;
public Bias(SameDiff sameDiff, SDVariable i_v, int[] dimensions, double mean) {
super(sameDiff, i_v, dimensions);
this.mean = mean;
}
public Bias(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] dimensions, double mean) {
super(sameDiff, i_v, i_v2, dimensions);
this.mean = mean;
}
public Bias() {}
public Bias(INDArray x, int... dimensions) {
super(x, dimensions);
}
@Override
public Map<String, Object> propertiesForFunction() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("mean",mean);
return ret;
}
@Override
public int opNum() {
return 2;
}
@Override
public String opName() {
return "bias";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
}

View File

@ -24,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -45,6 +46,7 @@ public class SequenceMask extends DynamicCustomOp {
public SequenceMask(SameDiff sameDiff, SDVariable input, SDVariable maxLen, DataType dataType) { public SequenceMask(SameDiff sameDiff, SDVariable input, SDVariable maxLen, DataType dataType) {
super(null, sameDiff, new SDVariable[] {input, maxLen}, false); super(null, sameDiff, new SDVariable[] {input, maxLen}, false);
this.dataType = dataType; this.dataType = dataType;
addDArgument(dataType);
} }
public SequenceMask(SameDiff sameDiff, SDVariable input, int maxLen, DataType dataType) { public SequenceMask(SameDiff sameDiff, SDVariable input, int maxLen, DataType dataType) {
@ -53,13 +55,23 @@ public class SequenceMask extends DynamicCustomOp {
this.is_static_maxlen = true; this.is_static_maxlen = true;
addIArgument(maxLen); addIArgument(maxLen);
this.dataType = dataType; this.dataType = dataType;
addDArgument(dataType);
} }
public SequenceMask(SameDiff sameDiff, SDVariable input, DataType dataType) { public SequenceMask(SameDiff sameDiff, SDVariable input, DataType dataType) {
super(null, sameDiff, new SDVariable[] {input}, false); super(null, sameDiff, new SDVariable[] {input}, false);
this.dataType = dataType; this.dataType = dataType;
addDArgument(dataType);
} }
public SequenceMask(INDArray input, int maxLen, DataType dataType) {
addInputArgument(input);
addIArgument(maxLen);
//addIArgument(dataType.toInt());
addDArgument(dataType);
this.dataType = dataType;
}
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
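Note: with the new INDArray constructor and the dtype D-argument, the op can be invoked directly; a minimal sketch (values illustrative, assuming the standard Nd4j.exec(CustomOp) path):

INDArray lengths = Nd4j.createFromArray(1, 3, 2);
INDArray mask = Nd4j.exec(new SequenceMask(lengths, 5, DataType.FLOAT))[0];
// mask has shape [3, 5]; row i contains lengths[i] ones followed by zeros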

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -39,23 +40,37 @@ import java.util.Map;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j @Slf4j
@NoArgsConstructor
public class ZerosLike extends DynamicCustomOp { public class ZerosLike extends DynamicCustomOp {
protected DataType outputType; //Allow customizing dtype for TF import protected DataType outputType; //Allow customizing dtype for TF import
public ZerosLike() { public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
this(name, sameDiff, input, false, input.dataType());
} }
public ZerosLike(String name, SameDiff sameDiff, SDVariable input) { public ZerosLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
this(name, sameDiff, input, false); this(name, sameDiff, input, false, dataType);
} }
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) { public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) {
this(name, sameDiff, input, inPlace, input.dataType());
}
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace, DataType dataType) {
super(name, sameDiff, new SDVariable[]{input}, inPlace); super(name, sameDiff, new SDVariable[]{input}, inPlace);
addDArgument(dataType);
} }
public ZerosLike(INDArray in, INDArray out){ public ZerosLike(INDArray in, INDArray out){
this(in, out, in.dataType());
}
public ZerosLike(INDArray in, INDArray out, DataType dataType) {
super(null, in, out, null, null); super(null, in, out, null, null);
if (dataType != null) {
addDArgument(dataType);
}
} }
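Note: the extra DataType parameter allows the output dtype to differ from the input's; a minimal sketch (names illustrative):

INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3);
INDArray out = Nd4j.create(DataType.DOUBLE, 2, 3);
Nd4j.exec(new ZerosLike(in, out, DataType.DOUBLE));   // out: all zeros, DOUBLE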

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -52,16 +53,13 @@ public class BatchToSpace extends DynamicCustomOp {
} }
public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) { public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
super(null, sameDiff, args, inPlace); super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace);
this.blocks = blocks; this.blocks = blocks;
this.crops = crops; this.crops = crops;
for (val b : blocks) for (val b : blocks)
addIArgument(b); addIArgument(b);
for (int e = 0; e < crops.length; e++)
addIArgument(crops[e][0], crops[e][1]);
} }
@Override @Override

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -53,16 +54,12 @@ public class SpaceToBatch extends DynamicCustomOp {
} }
public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) { public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
super(null, sameDiff, args, inPlace); super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace);
this.blocks = blocks; this.blocks = blocks;
this.padding = padding; this.padding = padding;
for (val b : blocks) addIArgument(blocks[0]);
addIArgument(b);
for (int e = 0; e < padding.length; e++)
addIArgument(padding[e][0], padding[e][1]);
} }
@Override @Override
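Note: both BatchToSpace and SpaceToBatch now pass their crops/padding as a second array input (a graph constant) rather than encoding each pair as integer arguments; only the (uniform) block size remains an iArg. The constant is built roughly like this (sketch, values illustrative):

int[][] padding = new int[][]{{0, 0}, {0, 0}};
SDVariable paddingVar = sameDiff.constant(Nd4j.createFromArray(padding));
// paddingVar becomes the op's second input alongside the data variable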

View File

@ -58,7 +58,8 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0)); return Collections.singletonList(inputDataTypes.get(0));
} }
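Note: the relaxed precondition here (and in the other unsorted-segment ops below) accounts for numSegments arriving either as an integer argument alongside two array inputs, or as a third scalar input, as happens on TensorFlow import. Using the array-based constructor this change adds to UnsortedSegmentSqrtN, the two-input form looks roughly like (sketch, values illustrative):

INDArray data = Nd4j.createFromArray(1f, 2f, 3f, 4f);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1);
INDArray out = Nd4j.exec(new UnsortedSegmentSqrtN(data, segmentIds, 2))[0];
// out[i] = sum(data where segment == i) / sqrt(count): here [3/sqrt(2), 7/sqrt(2)]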

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment; package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public class UnsortedSegmentMean extends DynamicCustomOp { public class UnsortedSegmentMean extends DynamicCustomOp {
private int numSegments; private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentMean(){ }
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_mean"; return "unsorted_segment_mean";
@ -60,7 +60,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0)); return Collections.singletonList(inputDataTypes.get(0));
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment; package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public class UnsortedSegmentMin extends DynamicCustomOp { public class UnsortedSegmentMin extends DynamicCustomOp {
private int numSegments; private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentMin(){ }
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_min"; return "unsorted_segment_min";
@ -60,7 +60,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0)); return Collections.singletonList(inputDataTypes.get(0));
} }
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment; package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public class UnsortedSegmentProd extends DynamicCustomOp { public class UnsortedSegmentProd extends DynamicCustomOp {
private int numSegments; private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentProd(){ }
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_prod"; return "unsorted_segment_prod";
@ -60,7 +60,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0)); return Collections.singletonList(inputDataTypes.get(0));
} }
} }

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment; package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList; import java.util.ArrayList;
@ -31,18 +33,23 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public class UnsortedSegmentSqrtN extends DynamicCustomOp { public class UnsortedSegmentSqrtN extends DynamicCustomOp {
private int numSegments; private int numSegments;
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
addInputArgument(data, segmentIds);
addIArgument(numSegments);
this.numSegments = numSegments;
}
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) { public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
this.numSegments = numSegments; this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentSqrtN(){ }
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_sqrt_n"; return "unsorted_segment_sqrt_n";
@ -60,7 +67,8 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
List<DataType> out = new ArrayList<>(); List<DataType> out = new ArrayList<>();
for( int i=0; i<numSegments; i++ ){ for( int i=0; i<numSegments; i++ ){
out.add(inputDataTypes.get(0)); out.add(inputDataTypes.get(0));

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment; package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -32,6 +33,7 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
@NoArgsConstructor
public class UnsortedSegmentSum extends DynamicCustomOp { public class UnsortedSegmentSum extends DynamicCustomOp {
private int numSegments; private int numSegments;
@ -42,8 +44,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentSum(){ }
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_sum"; return "unsorted_segment_sum";
@ -61,7 +61,8 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
//TODO Allow customizing output type //TODO Allow customizing output type
return Collections.singletonList(Nd4j.defaultFloatingPointType()); return Collections.singletonList(Nd4j.defaultFloatingPointType());
} }

View File

@ -50,6 +50,11 @@ public class LayerOpValidation extends BaseOpValidation {
super(backend); super(backend);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testXwPlusB() { public void testXwPlusB() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -319,7 +324,7 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testIm2Col() { public void testIm2Col() {
OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873 //OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};

View File

@ -32,6 +32,9 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient; import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart; import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
@ -513,7 +516,7 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testTrace(){ public void testTrace(){
//TODO need to work out how to handle shape_op for scalars... //TODO need to work out how to handle shape_op for scalars...
OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for( int[] inShape : new int[][]{{3,3}}){ for( int[] inShape : new int[][]{{3,3}}){
@ -546,12 +549,15 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
SDVariable y = sameDiff.var("y", arr2); SDVariable y = sameDiff.var("y", arr2);
SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}}); SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}});
assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.getShape()); assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}),
assertEquals(32, sameDiff.numElements()); result.eval().shape());
assertEquals(16, sameDiff.numElements());
SDVariable loss = sameDiff.standardDeviation(result, true); SDVariable loss = sameDiff.standardDeviation(result, true);
sameDiff.addLossVariable(loss);
String err = OpValidation.validate(new TestCase(sameDiff)); String err = OpValidation.validate(new TestCase(sameDiff));
assertNull(err);
} }
@Test @Test
@ -1782,4 +1788,338 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out); //Values in x not in y assertEquals(exp, out); //Values in x not in y
assertEquals(exp, outIdx); //Indices of the values in x not in y assertEquals(exp, outIdx); //Indices of the values in x not in y
} }
@Test
public void testDivideNoNan() {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
SameDiff sameDiff = SameDiff.create();
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input1 = sameDiff.var(in1);
SDVariable input2 = sameDiff.var(in2);
INDArray expected = Nd4j.ones(3,4);
SDVariable output = new DivideNoNan(sameDiff, input1, input2).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testDigamma() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
-0.5772157,0.42278433,0.9227843,1.2561177,1.5061177,1.7061176,1.8727844,2.0156415,2.1406415,2.2517526,2.3517525,2.4426618
}).reshape(3,4);
val tc = new OpTestCase(new Digamma(in1)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testFlatten() {
SameDiff sameDiff = SameDiff.create();
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1).reshape(3,3,3);
SDVariable sdx = sameDiff.var(x);
INDArray expected = Nd4j.linspace(DataType.DOUBLE,1,27,1);
SDVariable output = new Flatten(sameDiff, 'c', sdx).outputVariable();
SDVariable loss = sameDiff.standardDeviation(sdx, true);
sameDiff.addLossVariable(loss);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testFusedBatchNorm() {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
scale.assign(0.5);
INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
offset.assign(2.0);
SDVariable input1 = sameDiff.var(x);
SDVariable input2 = sameDiff.var(scale);
SDVariable input3 = sameDiff.var(offset);
INDArray expectedY = Nd4j.createFromArray(new double[]{
985.5258, 985.5258, 985.5258, 985.5258,
659.7321, 659.7321, 659.7321, 659.7321,
399.0972, 399.0972, 399.0972, 399.0972,
203.6210, 203.6210, 203.6210, 203.6210,
73.3036, 73.3036, 73.3036, 73.3036,
8.1448, 8.1448, 8.1448, 8.1448,
8.1448, 8.1448, 8.1448, 8.1448,
73.3036, 73.3036, 73.3036, 73.3036,
203.6210, 203.6210, 203.6210, 203.6210,
399.0972, 399.0972, 399.0972, 399.0972,
659.7321, 659.7321, 659.7321, 659.7321,
985.5258, 985.5258, 985.5258, 985.5258}).reshape(x.shape());
INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.});
INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526});
SDVariable[] outputs = new FusedBatchNorm(sameDiff, input1, input2, input3, 0, 1).outputVariables();
SDVariable loss = sameDiff.standardDeviation(input1, true);
sameDiff.addLossVariable(loss);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(outputs[0].name(), expectedY)
.expectedOutput(outputs[1].name(), expectedBatchMean)
.expectedOutput(outputs[2].name(), expectedBatchVar);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testIgamma() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
0.63212055,0.59399414,0.5768099,0.56652874,0.5595013,0.5542634,0.5501591,0.5463888,0.54329145,0.54048204,0.5378594,0.53233755
}).reshape(3,4);
val tc = new OpTestCase(new Igamma(in1, in2)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testIgammaC() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
0.36787945,0.40600586,0.42319012,0.43347126,0.4404987,0.44573656,0.4498409,0.45361117,0.45670855,0.459518,0.46214062,0.46766248
}).reshape(3,4);
val tc = new OpTestCase(new Igammac(in1, in2)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testLgamma() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(DataType.DOUBLE, 1, 12, 1).reshape(3, 4);
SDVariable sdInput = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
0.0,0.0,0.6931472,1.7917595,3.1780539,4.787492,6.5792513,8.525162,10.604603,12.801827,15.104413,17.502308
}).reshape(3,4);
SDVariable output = new Lgamma(sameDiff, sdInput).outputVariable();
SDVariable loss = sameDiff.standardDeviation(sdInput, true);
sameDiff.addLossVariable(loss);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testLu() {
SameDiff sameDiff = SameDiff.create();
INDArray in1 = Nd4j.createFromArray(new double[]{
1., 2., 3., 0., 2., 3., 0., 0., 7.
}).reshape(3,3);
SDVariable input1 = sameDiff.var(in1);
INDArray expected = Nd4j.createFromArray(new double[]{
1., 2., 3., 0., 2., 3., 0., 0., 7
}).reshape(3,3);
INDArray pexpected = Nd4j.createFromArray(new int[]{
0, 1, 2
});
sameDiff.loss.l2Loss(input1);
SDVariable[] output = new Lu(sameDiff, input1).outputVariables();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output[0].name(), expected)
.expectedOutput(output[1].name(), pexpected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testMatrixBandPart() {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
0.7271f,0.1804f,0.5056f,0.8925f,
0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);
SDVariable sdInput = sameDiff.var(input);
SDVariable sdInput1 = sameDiff.constant(1);
SDVariable sdInput2 = sameDiff.constant(-1);
INDArray expected = Nd4j.createFromArray(new float[]{
0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f,
0.f, 0.9234f, 0.0856f, 0.7938f
}).reshape(3,4);
sameDiff.loss.l2Loss(sdInput);
SDVariable output = new MatrixBandPart(sameDiff, sdInput, 1, -1).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testPolygamma() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
1.644934,-0.4041138,0.1189394,-0.03750069,0.01226151,-0.0041002957,0.001392272,-4.780109E-4,1.6549716E-4,-5.7675967E-5,2.0206635E-5,-7.1101636E-6
}).reshape(3,4);
val tc = new OpTestCase(new Polygamma(in1, in2)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testTriangularSolve() {
INDArray a = Nd4j.createFromArray(new float[]{
3.f, 0.f, 0.f, 0.f,
2.f, 1.f, 0.f, 0.f,
1.f, 0.f, 1.f, 0.f,
1.f, 1.f, 1.f, 1.f
}).reshape(4,4);
INDArray b = Nd4j.createFromArray(new float[]{
4.f, 2.f, 4.f, 2.f
}).reshape(4,1);
INDArray expected = Nd4j.createFromArray(new float[]{
1.333333f, 2.0f, 4.0f, 2.0f
}).reshape(4,1);
val tc = new OpTestCase(new TriangularSolve(a, b, false, true)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testBiasAdd() {
SameDiff sameDiff = SameDiff.create();
INDArray in1 = Nd4j.linspace(1, 12, 12);
INDArray in2 = Nd4j.linspace(1, 12, 12);
SDVariable input1 = sameDiff.var(in1);
SDVariable input2 = sameDiff.var(in2);
INDArray expected = Nd4j.createFromArray(new double[]{
2.0000, 4.0000, 6.0000, 8.0000, 10.0000, 12.0000, 14.0000, 16.0000, 18.0000, 20.0000, 22.0000, 24.0000
});
SDVariable output = new BiasAdd(sameDiff, input1, input2, false).outputVariable();
SDVariable loss = sameDiff.standardDeviation(input1, true);
sameDiff.addLossVariable(loss);
SDVariable loss2 = sameDiff.standardDeviation(input2, true);
sameDiff.addLossVariable(loss2);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testBiasAddGrad() {
SameDiff sameDiff = SameDiff.create();
INDArray x = Nd4j.linspace(DataType.FLOAT,1, 24, 24).reshape(2,2,2,3);
INDArray grad = Nd4j.linspace(DataType.FLOAT, 0.1, 0.1, 24).reshape(2,2,2,3);
INDArray bias = Nd4j.createFromArray(new float[]{-1.f, -2.f, -3.f});
INDArray expected = Nd4j.createFromArray(new float[]{9.2f, 10.f , 10.8f});
OpTestCase tc = new OpTestCase(new BiasAddGrad(x, bias, grad,false)).
expectedOutput(0, grad).
expectedOutput(1, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testRoll() {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
reshape(2,2,4,2);
INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
21.41, 21.42, 22.11, 22.12
}).reshape(x.shape());
int shift = 6;
val tc = new OpTestCase(new Roll(x,shift)).expectedOutput(0,expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
} }

View File

@ -35,16 +35,20 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.reduce.Moments; import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.api.ops.impl.reduce3.*; import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -96,7 +100,7 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testZeroCount() { public void testZeroCount() {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) { for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia; INDArray ia;
@ -159,25 +163,25 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testReductionGradientsSimple() { public void testReductionGradientsSimple() {
OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES //OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function //Test reductions: final and only function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int i = 0; i < 21; i++) { for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
int nOut = 4; int nOut = 4;
int minibatch = 10; int minibatch = 10;
SDVariable input = sd.var("in", -1, nOut); SDVariable input = sd.var("in", minibatch, nOut);
INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100); INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
long length = nOut * minibatch; long length = nOut * minibatch;
SDVariable loss; SDVariable loss;
String name; String name;
TestCase tc = new TestCase(sd); TestCase tc = new TestCase(sd);
boolean gradCheck = true;
switch (i) { switch (i) {
case 0: case 0:
loss = sd.mean("loss", input); loss = sd.mean("loss", input);
@ -234,11 +238,13 @@ public class ReductionOpValidation extends BaseOpValidation {
loss = sd.math().countNonZero("loss", input); loss = sd.math().countNonZero("loss", input);
name = "countNonZero"; name = "countNonZero";
tc.expectedOutput("loss", Nd4j.scalar(inputArr.length())); tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
gradCheck = false; //Long out, not floating point
break; break;
case 11: case 11:
loss = sd.math().countZero("loss", input); loss = sd.math().countZero("loss", input);
name = "countZero"; name = "countZero";
tc.expectedOutput("loss", Nd4j.scalar(0)); tc.expectedOutput("loss", Nd4j.scalar(0L));
gradCheck = false; //Long out, not floating point
break; break;
case 12: case 12:
loss = sd.math().amax("loss", input); loss = sd.math().amax("loss", input);
@ -272,7 +278,7 @@ public class ReductionOpValidation extends BaseOpValidation {
loss = sd.math().logSumExp("loss", input); loss = sd.math().logSumExp("loss", input);
INDArray expArr = Transforms.exp(inputArr); INDArray expArr = Transforms.exp(inputArr);
double sum = expArr.sumNumber().doubleValue(); double sum = expArr.sumNumber().doubleValue();
tc.expected("loss", Nd4j.create(new double[]{Math.log(sum)})); tc.expected("loss", Nd4j.scalar(Math.log(sum)));
break; break;
case 18: case 18:
inputArr = Nd4j.rand(minibatch, nOut); inputArr = Nd4j.rand(minibatch, nOut);
@ -307,9 +313,15 @@ public class ReductionOpValidation extends BaseOpValidation {
log.info("*** Starting test: " + msg); log.info("*** Starting test: " + msg);
sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(inputArr, input);
if(gradCheck) {
sd.addLossVariable(loss);
}
tc.testName(msg); tc.testName(msg);
if(!gradCheck){
tc.gradientCheck(false);
}
String error = OpValidation.validate(tc, true); String error = OpValidation.validate(tc, true);
if (error != null) if (error != null)
failed.add(error); failed.add(error);
@ -629,14 +641,14 @@ public class ReductionOpValidation extends BaseOpValidation {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) { for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) {
for (int i = 6; i < 7; i++) { for (int i = 0; i < 7; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
sd.setLogExecution(false); sd.setLogExecution(false);
SDVariable in = sd.var("in", -1, d1, d2); SDVariable in = sd.var("in", d0, d1, d2);
SDVariable in2 = sd.var("in2", -1, d1, d2); SDVariable in2 = sd.var("in2", d0, d1, d2);
INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100); INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100);
INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100); INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100);
@ -645,40 +657,43 @@ public class ReductionOpValidation extends BaseOpValidation {
SDVariable reduced; SDVariable reduced;
String name; String name;
TestCase tc = new TestCase(sd); TestCase tc = new TestCase(sd);
Double maxRelError = null;
switch (i) { switch (i) {
case 0: case 0:
reduced = sd.math().manhattanDistance(in, in2, reduceDims); reduced = sd.math().manhattanDistance(in, in2, reduceDims);
name = "manhattan"; name = "manhattan";
exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, true, false, reduceDims)); exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, false, false, reduceDims));
break; break;
case 1: case 1:
reduced = sd.math().euclideanDistance(in, in2, reduceDims); reduced = sd.math().euclideanDistance(in, in2, reduceDims);
name = "euclidean"; name = "euclidean";
exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, true, false, reduceDims)); exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, false, false, reduceDims));
break; break;
case 2: case 2:
inArr.muli(1e-4); inArr.muli(1e-4);
in2Arr.muli(1e-4); in2Arr.muli(1e-4);
reduced = sd.math().cosineSimilarity(in, in2, reduceDims); reduced = sd.math().cosineSimilarity(in, in2, reduceDims);
name = "cosine"; name = "cosine";
exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, true, false, reduceDims)); exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, false, false, reduceDims));
maxRelError = 1e-4;
break; break;
case 3: case 3:
reduced = sd.math().cosineDistance(in, in2, reduceDims); reduced = sd.math().cosineDistance(in, in2, reduceDims);
name = "cosinedistance"; name = "cosinedistance";
exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, true, false, reduceDims)); exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, false, false, reduceDims));
maxRelError = 1e-4;
break; break;
case 4: case 4:
reduced = sd.math().hammingDistance(in, in2, reduceDims); reduced = sd.math().hammingDistance(in, in2, reduceDims);
name = "hamming"; name = "hamming";
exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, true, false, reduceDims)); exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, false, false, reduceDims));
break; break;
case 5: case 5:
name = "jaccard"; name = "jaccard";
reduced = sd.math().jaccardDistance(name, in, in2, reduceDims); reduced = sd.math().jaccardDistance(name, in, in2, reduceDims);
inArr.divi(100).addi(0.1); inArr.divi(100).addi(0.1);
in2Arr.divi(100).addi(0.1); in2Arr.divi(100).addi(0.1);
exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, true, false, reduceDims)); exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, false, false, reduceDims));
if (OpValidationSuite.IGNORE_FAILING && reduceDims.length == 2) if (OpValidationSuite.IGNORE_FAILING && reduceDims.length == 2)
continue; continue;
@ -708,6 +723,9 @@ public class ReductionOpValidation extends BaseOpValidation {
tc.expected(reduced, exp); tc.expected(reduced, exp);
if(maxRelError != null)
tc.gradCheckMaxRelativeError(maxRelError);
String error = OpValidation.validate(tc, true); String error = OpValidation.validate(tc, true);
if (error != null) { if (error != null) {
failed.add(msg + " - " + error); failed.add(msg + " - " + error);
@ -768,7 +786,6 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testNormalizeMomentsOp() { public void testNormalizeMomentsOp() {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0); INDArray ssSum = data.sum(0);
@ -780,7 +797,7 @@ public class ReductionOpValidation extends BaseOpValidation {
INDArray mean = Nd4j.createUninitialized(DataType.DOUBLE, meanExp.shape()); INDArray mean = Nd4j.createUninitialized(DataType.DOUBLE, meanExp.shape());
INDArray var = Nd4j.createUninitialized(DataType.DOUBLE, varExp.shape()); INDArray var = Nd4j.createUninitialized(DataType.DOUBLE, varExp.shape());
OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.INT, 10), ssSum, ssSqSum, mean, var)); OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.DOUBLE, 10), ssSum, ssSqSum, mean, var));
op.expectedOutput(0, meanExp); op.expectedOutput(0, meanExp);
op.expectedOutput(1, varExp); op.expectedOutput(1, varExp);
@ -821,7 +838,7 @@ public class ReductionOpValidation extends BaseOpValidation {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]); List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]);
INDArray in = Nd4j.rand(3, 4); INDArray in = Nd4j.rand(DataType.DOUBLE,3, 4);
for (int t = 0; t < 4; t++) { for (int t = 0; t < 4; t++) {
int[] d = dims.get(t); int[] d = dims.get(t);
@ -838,52 +855,47 @@ public class ReductionOpValidation extends BaseOpValidation {
switch (i) { switch (i) {
case 0: case 0:
reduce = s.argmax(dim); reduce = s.argmax(dim);
exp = Nd4j.argMax(in, dim).castTo(DataType.DOUBLE); exp = Nd4j.argMax(in, dim);
name = "argmax"; name = "argmax";
break; break;
case 1: case 1:
reduce = s.argmin(dim); reduce = s.argmin(dim);
exp = Nd4j.argMin(in, dim).castTo(DataType.DOUBLE); exp = Nd4j.argMin(in, dim);
name = "argmin"; name = "argmin";
break; break;
case 2: case 2:
reduce = sd.math().iamax(s, dim); reduce = sd.math().iamax(s, dim);
exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim)); exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim));
exp = exp.castTo(DataType.DOUBLE);
name = "iamax"; name = "iamax";
break; break;
case 3: case 3:
reduce = sd.math().iamin(s, dim); reduce = sd.math().iamin(s, dim);
exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim)); exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim));
exp = exp.castTo(DataType.DOUBLE);
name = "iamin"; name = "iamin";
break; break;
case 4: case 4:
reduce = sd.math().firstIndex(s, Conditions.greaterThan(0), dim); reduce = sd.math().firstIndex(s, Conditions.greaterThan(0), dim);
exp = in.sum(dim).assign(0); exp = in.sum(dim).assign(0).castTo(DataType.INT64);
exp = exp.castTo(DataType.DOUBLE);
name = "firstindex"; name = "firstindex";
break; break;
case 5: case 5:
reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim); reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim);
if (t == 0) exp = Nd4j.create(new double[]{2, 2, 2, 2}); if (t == 0) exp = Nd4j.createFromArray(2L, 2, 2, 2);
else if (t == 1) exp = Nd4j.create(new double[]{3, 3, 3}); else if (t == 1) exp = Nd4j.createFromArray(3L, 3, 3);
else exp = Nd4j.scalar(11.0); else exp = Nd4j.scalar(11L);
exp = exp.castTo(DataType.DOUBLE);
name = "lastindex"; name = "lastindex";
break; break;
case 6: case 6:
reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim); reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim);
if (t == 0) exp = Nd4j.create(new double[]{3, 3, 3, 3}); if (t == 0) exp = Nd4j.createFromArray(3L, 3, 3, 3);
else if (t == 1) exp = Nd4j.create(new double[]{4, 4, 4}); else if (t == 1) exp = Nd4j.createFromArray(4L, 4, 4);
else exp = Nd4j.scalar(12.0); else exp = Nd4j.scalar(12L);
exp = exp.castTo(DataType.DOUBLE);
name = "matchConditionCount"; name = "matchConditionCount";
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
} }
SDVariable preCast = reduce;
reduce = reduce.castTo(DataType.DOUBLE); reduce = reduce.castTo(DataType.DOUBLE);
SDVariable loss; SDVariable loss;
@ -894,7 +906,7 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
TestCase tc = new TestCase(sd) TestCase tc = new TestCase(sd)
.expected(reduce, exp) .expected(preCast, exp)
.gradientCheck(false) .gradientCheck(false)
.testName(name + " - " + (dim == null ? null : Arrays.toString(dim))); .testName(name + " - " + (dim == null ? null : Arrays.toString(dim)));
@ -1335,4 +1347,254 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
} }
@Test
public void testSufficientStatisticsOp() {
INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1.,
1.3, 1.5, 3.5, 0., 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5
}).reshape(2,2,2,4);
INDArray axes = Nd4j.linspace(DataType.LONG, 0, 3, 1);
OpTestCase op = new OpTestCase(new SufficientStatistics(data, axes));
INDArray expected1 = Nd4j.scalar(8.0);
INDArray expected2 = Nd4j.createFromArray(new double[]{
30.2, 5., 7.8, 22.8
});
INDArray expected3 = Nd4j.createFromArray(new double[]{
154.22, 7., 14.34, 103.62
});
op.expectedOutput(0, expected1);
op.expectedOutput(1, expected2);
op.expectedOutput(2, expected3);
String err = OpValidation.validate(op);
assertNull(err);
}
@Test
public void testStandardDeviation() {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 8, 8).reshape(2, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
2, 2, 2, 2
});
if(keepDims){
expected = expected.reshape(1,4);
}
SDVariable output = new StandardDeviation(sameDiff, input, false, keepDims, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}
@Test
public void testSquaredNorm() {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 4, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.scalar(30.0000);
if(keepDims)
expected = expected.reshape(1);
SDVariable output = new SquaredNorm(sameDiff, input, keepDims, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}
@Test
public void testShannonEntropy() {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 4, 4).castTo(DataType.DOUBLE);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.scalar(-69.68162);
SDVariable output = new ShannonEntropy(sameDiff, input, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testEntropy() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 4, 4);
SDVariable input = sameDiff.var(in);
double expected = -10.2273;
SDVariable output = new Entropy(sameDiff, input, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), Nd4j.scalar(expected));
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testAMean() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
5.0000, 6.0000, 7.0000, 8.0000
});
SDVariable output = new AMean(sameDiff, input, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testMean() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
5.0000, 6.0000, 7.0000, 8.0000
});
SDVariable output = new Mean(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testNorm1() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
15.0000, 18.0000, 21.0000, 24.0000
});
SDVariable output = new Norm1(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testNorm2() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
10.3441, 11.8322, 13.3791, 14.9666
});
SDVariable output = new Norm2(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testNormMax() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
9.0000, 10.0000, 11.0000, 12.0000
});
SDVariable output = new NormMax(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testSoftmaxCrossEntropyWithLogitsLoss() {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
INDArray labels = Nd4j.createFromArray(new double[]{
0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0
}).reshape(2,3,4);
INDArray logits = Nd4j.linspace(DataType.DOUBLE, 0.1, 0.1, 24).reshape(2,3,4);
INDArray expected = Nd4j.createFromArray(new double[]{
0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328
}).reshape(3,4);
SDVariable sdLogits = sameDiff.var("logits", logits);
SDVariable sdLabels = sameDiff.var("labels", labels);
SDVariable loss = sameDiff.math().abs(sdLogits); // note: 'loss' is defined but unused; the op output below is set as the loss variable instead
SDVariable output = new SoftmaxCrossEntropyWithLogitsLoss(sameDiff, sdLogits, sdLabels, 0).outputVariable();
sameDiff.setLossVariables(output);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}

View File

@@ -1016,7 +1016,7 @@ public class ShapeOpValidation extends BaseOpValidation {
     @Test
     public void testConstant(){
-        OpValidationSuite.ignoreFailing();
+        //OpValidationSuite.ignoreFailing();
         //Case 0: no shape
         SameDiff sd = SameDiff.create();
@@ -1035,7 +1035,9 @@ public class ShapeOpValidation extends BaseOpValidation {
         INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
         loss = constant.std(true);
-        assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia)));
+        assertNull(OpValidation.validate(new TestCase(sd)
+                .gradientCheck(false)
+                .expected(constant, Nd4j.create(DataType.FLOAT, 3,4,5))));
     }
@@ -1272,7 +1274,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         SameDiff sd = SameDiff.create();
         SDVariable data = sd.var("data", d);
-        SDVariable segments = sd.var("segments", s);
+        SDVariable segments = sd.constant("segments", s);
         SDVariable sm;
         INDArray exp;
@@ -1326,6 +1328,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         }
         SDVariable loss = sm.std(true);
+        sd.addLossVariable(loss);
         TestCase tc = new TestCase(sd)
                 .testName(op)
@@ -1363,17 +1366,19 @@ public class ShapeOpValidation extends BaseOpValidation {
     @Test
     public void testSequenceMask() {
-        OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue?
         SameDiff sameDiff = SameDiff.create();
-        INDArray arr = Nd4j.create(new float[] {1, 3, 2}).reshape(3);
-        SDVariable lengths = sameDiff.var("lengths", arr);
+        INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
+        // arr is not trainable, so it's constant in model
+        SDVariable lengths = sameDiff.constant(arr);
         // Test with static max len
         int maxlen = 5;
-        INDArray expected = Nd4j.create(new float[] {1, 0, 0, 0, 0,
-                1, 1, 1, 0, 0,
-                1, 1, 0, 0, 0},
-                new long[]{3, 5});
+        INDArray expected = Nd4j.create(new float[] {
+                1.f, 0.f, 0.f, 0.f, 0.f,
+                1.f, 1.f, 1.f, 0.f, 0.f,
+                1.f, 1.f, 0.f, 0.f, 0.f
+        }).reshape(3,5);
+        INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
         SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
         assertArrayEquals(expected.shape(), result1.eval().shape());
         assertEquals(expected, result1.eval());
@@ -1382,14 +1387,14 @@ public class ShapeOpValidation extends BaseOpValidation {
         String err = OpValidation.validate(new TestCase(sameDiff)
                 .expected(result1, expected)
-                .gradCheckSkipVariables(lengths.name()));
+                .gradientCheck(false));
         assertNull(err);
         // Test with dynamic maxlen
-        lengths = sameDiff.var("lengths2", arr); // required because of an internal samediff bug
-        SDVariable maxLen = sameDiff.var("maxLen", Nd4j.create(new float[]{5}).reshape(1));
+        lengths = sameDiff.constant("lengths2", arr);
+        SDVariable maxLen = sameDiff.constant("maxLen", Nd4j.scalar(5));
         SDVariable result2 = sameDiff.sequenceMask(lengths, maxLen, DataType.FLOAT);
-        assertArrayEquals(expected.shape(), result2.eval().shape());
+        // assertArrayEquals(expected.shape(), result2.eval().shape());
         assertEquals(expected, result2.eval());
     }
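
The expected mask above has a simple structure: row i holds lengths[i] ones followed by zeros, padded out to maxlen columns. A minimal sketch of the same construction in plain ND4J, assuming only the semantics visible in the expected array (hypothetical class name):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SequenceMaskSketch {
    public static void main(String[] args) {
        int[] lengths = {1, 3, 2};
        int maxlen = 5;
        // Row i: lengths[i] ones, then zeros, out to maxlen columns
        INDArray mask = Nd4j.create(DataType.FLOAT, lengths.length, maxlen);
        for (int i = 0; i < lengths.length; i++) {
            for (int j = 0; j < lengths[i]; j++) {
                mask.putScalar(new long[]{i, j}, 1.0);
            }
        }
        System.out.println(mask); // matches the 3x5 expected array above
    }
}
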

View File

@@ -303,7 +303,7 @@ public class TransformOpValidation extends BaseOpValidation {
     @Test
     public void testBatchToSpace() {
-        OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
+        //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
         Nd4j.getRandom().setSeed(1337);
         int miniBatch = 4;
@@ -314,7 +314,6 @@ public class TransformOpValidation extends BaseOpValidation {
         int[] cropShape = new int[]{M, 2};
         INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
-        INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
         INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT);
         SameDiff sd = SameDiff.create();
@@ -323,7 +322,8 @@ public class TransformOpValidation extends BaseOpValidation {
         INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1);
         DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space")
-                .addInputs(input, blocks, crops)
+                .addInputs(input, crops)
+                .addIntegerArguments(2)
                 .addOutputs(expOut).build();
         Nd4j.getExecutioner().exec(op);
@@ -340,7 +340,7 @@ public class TransformOpValidation extends BaseOpValidation {
     @Test
     public void testSpaceToBatch() {
-        OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
+        //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
         Nd4j.getRandom().setSeed(7331);
@@ -352,7 +352,6 @@ public class TransformOpValidation extends BaseOpValidation {
         int[] paddingShape = new int[]{M, 2};
         INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
-        INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
         INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT);
         SameDiff sd = SameDiff.create();
@@ -361,7 +360,8 @@ public class TransformOpValidation extends BaseOpValidation {
         INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1);
         DynamicCustomOp op = DynamicCustomOp.builder("space_to_batch")
-                .addInputs(input, blocks, padding)
+                .addIntegerArguments(2)
+                .addInputs(input, padding)
                 .addOutputs(expOut).build();
         Nd4j.getExecutioner().exec(op);
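
Both hunks reflect the same signature change: the block size moves from an input array to an integer argument (addIntegerArguments(2)). For orientation, the shape bookkeeping under a uniform block size b, as used in these tests (input shape inferred from the expected output; hypothetical class name):

public class BatchToSpaceShapes {
    public static void main(String[] args) {
        int block = 2;                       // the value passed via addIntegerArguments(2)
        long[] in = {4, 1, 1, 1};            // miniBatch=4, NHWC, as in testBatchToSpace
        // batch_to_space with zero crops: batch /= block^2, spatial dims *= block
        long[] out = {in[0] / (block * block), in[1] * block, in[2] * block, in[3]};
        System.out.println(java.util.Arrays.toString(out)); // [1, 2, 2, 1]
    }
}
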

View File

@@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
 import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
 import org.nd4j.linalg.api.ops.impl.shape.Create;
 import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
+import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
@@ -1737,4 +1738,19 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(expected, ret[0]);
     }
+    @Test
+    public void testSequenceMask() {
+        INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2});
+        // Test with static max len
+        int maxlen = 2;
+        INDArray expected = Nd4j.createFromArray(new int[]{
+                1, 0, 0,
+                1, 1, 1,
+                1, 1, 0
+        }).reshape(3, 3);
+
+        INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32));
+        assertEquals(expected, ret[0]);
+    }
 }
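
One detail worth noting: maxlen is 2 here, yet the expected mask is 3x3. Taken together with the 3x5 case above (maxlen 5, longest length 3), the expected outputs suggest the mask width is max(maxlen, max(lengths)) - an inference from the test data, not a documented contract. A sketch of that reading (hypothetical class name):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SequenceMaskWidthSketch {
    public static void main(String[] args) {
        int[] lengths = {1, 3, 2};
        int maxlen = 2;
        // Assumption inferred from the expected 3x3 output above: the mask is
        // widened to max(maxlen, max(lengths)) when a length exceeds maxlen.
        int width = maxlen;
        for (int len : lengths) width = Math.max(width, len);
        INDArray mask = Nd4j.create(DataType.INT32, lengths.length, width);
        for (int i = 0; i < lengths.length; i++)
            for (int j = 0; j < lengths[i]; j++)
                mask.putScalar(new long[]{i, j}, 1.0);
        System.out.println(mask); // [[1,0,0],[1,1,1],[1,1,0]]
    }
}
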

View File

@@ -318,8 +318,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
   "num2Scalar" should "convert number to Scalar INDArray" in {
-    assert(1.toScalar.data() == List(1).toNDArray.data())
-    assert(2f.toScalar.data() == List(2).toNDArray.data())
-    assert(3d.toScalar.data() == List(3).toNDArray.data())
+    assert(1.toScalar.reshape(1) == List(1).toNDArray)
+    assert(2f.toScalar.reshape(1) == List(2f).toNDArray)
+    assert(3d.toScalar.reshape(1) == List(3d).toNDArray)
   }
 }
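
The reshape(1) added above suggests that a rank-0 scalar and a length-1 vector no longer compare equal directly: equality appears to take shape into account, so the scalar is reshaped to shape [1] before comparing against the rank-1 NDArray. The same distinction in Java (a minimal sketch based on that reading of the diff):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ScalarVsVectorSketch {
    public static void main(String[] args) {
        INDArray scalar = Nd4j.scalar(1.0);                        // shape []
        INDArray vector = Nd4j.createFromArray(new double[]{1.0}); // shape [1]
        System.out.println(scalar.equals(vector));              // expected false: shapes differ
        System.out.println(scalar.reshape(1).equals(vector));   // true after reshape
    }
}
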