From 7a1e32751694511fbb0f4f30c14165aa0dc72eb1 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 7 Oct 2019 16:59:47 +0300 Subject: [PATCH 1/4] Indexing demo Signed-off-by: Alexander Stoyakin --- .../src/main/scala/org/nd4s/samediff/SameDiff.scala | 13 +++++++++++++ .../src/test/scala/org/nd4s/samediff/MathTest.scala | 11 ++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index fc70b6d3f..6a8924365 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -44,6 +44,9 @@ class SameDiffWrapper { def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable = sd.`var`(name, dataType, shape: _*) + def bind(data: INDArray): SDVariable = + sd.`var`("", data) + def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable = sd.`var`(name, dataType, shape: _*) @@ -63,6 +66,16 @@ class SDVariableWrapper { def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) + def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) + + def apply(x: SDIndex)(y: SDIndex): SDVariable = + (x, y) match { + case (_, y) => thisVariable.get(SDIndex.all(), y) + case (x, _) => thisVariable.get(x, SDIndex.all()) + case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) + case (x, y) => thisVariable.get(x, y) + } + def add(other: Double): Unit = thisVariable.add(other) def *(other: SDVariable): SDVariable = diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index 104356f24..50a77fc14 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -205,9 +205,18 @@ class MathTest extends FlatSpec with Matchers { implicit val sd = SameDiff.create val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L) - val x = sd.`var`(arr) + val x = sd.bind(arr) val y = new SDVariableWrapper(x) x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr } + + "SDVariable " should "be indexable in 2d" in { + implicit val sd = SameDiff.create + + val arr = Nd4j.rand(2, 5) + val x = sd.bind(arr) + + println(x(SDIndex.point(0), _: SDIndex).getArr) + } } From b36b3fa1aa8a9be7a809bfcfb4ca8843aea630c5 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Tue, 8 Oct 2019 16:41:12 +0300 Subject: [PATCH 2/4] Some sugar Signed-off-by: Alexander Stoyakin --- .../main/scala/org/nd4s/samediff/SameDiff.scala | 15 +++++++++++++-- .../org/nd4s/samediff/implicits/Implicits.scala | 9 +++++++-- .../test/scala/org/nd4s/samediff/MathTest.scala | 16 ++++++++++++++-- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index 6a8924365..de14258b9 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -54,10 +54,21 @@ class SameDiffWrapper { sd.placeHolder(name, dataType, shape: _*) } +case class SDIndexWrapper(start: Long) { + + def ::(end: Long): SDIndex = + SDIndex.interval(start, end) +} + +object --- extends SDIndex { + val thisIndex: SDIndex = SDIndex.all() +} + class SDVariableWrapper { var thisVariable: SDVariable = null var isScalar: Boolean = false + val --- : SDIndex = SDIndex.all() def this(variable: SDVariable) { this @@ -68,13 +79,13 @@ class SDVariableWrapper { def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) - def apply(x: SDIndex)(y: SDIndex): SDVariable = + /*def apply(x: SDIndex, y: SDIndex): SDVariable = (x, y) match { case (_, y) => thisVariable.get(SDIndex.all(), y) case (x, _) => thisVariable.get(x, SDIndex.all()) case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) case (x, y) => thisVariable.get(x, y) - } + }*/ def add(other: Double): Unit = thisVariable.add(other) diff --git a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala index c10ff367c..0673cb3d8 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala @@ -15,9 +15,9 @@ ******************************************************************************/ package org.nd4s.samediff.implicits -import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff } +import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff } import org.nd4j.linalg.factory.Nd4j -import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper } +import org.nd4s.samediff.{ SDIndexWrapper, SDVariableWrapper, SameDiffWrapper } object Implicits { implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper = @@ -43,4 +43,9 @@ object Implicits { result.isScalar = true result } + + implicit def RangeToWrapper(start: Long): SDIndexWrapper = { + val result = new SDIndexWrapper(start) + result + } } diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index 50a77fc14..c82d1b45c 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -20,6 +20,7 @@ import org.nd4j.linalg.api.buffer.DataType import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.factory.Nd4j import org.nd4s.Implicits._ +import org.nd4s.NDOrdering import org.nd4s.samediff.implicits.Implicits._ import org.scalatest.{ FlatSpec, Matchers } @@ -214,9 +215,20 @@ class MathTest extends FlatSpec with Matchers { "SDVariable " should "be indexable in 2d" in { implicit val sd = SameDiff.create - val arr = Nd4j.rand(2, 5) + val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3, 3) + val x = sd.bind(arr) - println(x(SDIndex.point(0), _: SDIndex).getArr) + println(x(SDIndex.point(0), ---).getArr) + + val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr + println(data1) + + val data2 = x(0 :: 2, ---).getArr + println(data2) + + //assert(indices.indices == List(0, 1, 2, 3, 4, 5)) + //assert(indices.targetShape.toList == List(2, 3)) + } } From 4956048a11b9d439fc72405cdb52796655b391fb Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Tue, 8 Oct 2019 17:38:43 +0300 Subject: [PATCH 3/4] Tests added Signed-off-by: Alexander Stoyakin --- .../main/scala/org/nd4s/samediff/SameDiff.scala | 14 ++++---------- .../org/nd4s/samediff/implicits/Implicits.scala | 3 +++ .../test/scala/org/nd4s/samediff/MathTest.scala | 17 ++++++++++------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index de14258b9..801309c85 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -54,9 +54,9 @@ class SameDiffWrapper { sd.placeHolder(name, dataType, shape: _*) } -case class SDIndexWrapper(start: Long) { +case class SDIndexWrapper(end: Long) { - def ::(end: Long): SDIndex = + def ::(start: Long): SDIndex = SDIndex.interval(start, end) } @@ -75,18 +75,12 @@ class SDVariableWrapper { thisVariable = variable } + // Indexing def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) - /*def apply(x: SDIndex, y: SDIndex): SDVariable = - (x, y) match { - case (_, y) => thisVariable.get(SDIndex.all(), y) - case (x, _) => thisVariable.get(x, SDIndex.all()) - case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) - case (x, y) => thisVariable.get(x, y) - }*/ - + // Arithmetic def add(other: Double): Unit = thisVariable.add(other) def *(other: SDVariable): SDVariable = diff --git a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala index 0673cb3d8..004cb6703 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala @@ -48,4 +48,7 @@ object Implicits { val result = new SDIndexWrapper(start) result } + + implicit def LongToPoint(x: Long): SDIndex = + SDIndex.point(x) } diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index c82d1b45c..890f70853 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -219,16 +219,19 @@ class MathTest extends FlatSpec with Matchers { val x = sd.bind(arr) - println(x(SDIndex.point(0), ---).getArr) + x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval - val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr - println(data1) + val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval + val slice2 = x(0 :: 2, ---).eval + slice1 shouldBe slice2 + } - val data2 = x(0 :: 2, ---).getArr - println(data2) + "SDVariable " should "be indexable in 3d" in { + implicit val sd = SameDiff.create - //assert(indices.indices == List(0, 1, 2, 3, 4, 5)) - //assert(indices.targetShape.toList == List(2, 3)) + val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2) + val x = sd.bind(arr) + println(x(0,0,SDIndex.all()).eval) } } From 8472782d7f721ef67f47fe2607ad21a030fe9b3a Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Tue, 8 Oct 2019 19:08:43 +0300 Subject: [PATCH 4/4] More tests Signed-off-by: Alexander Stoyakin --- nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index 890f70853..dc41b31f6 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -232,6 +232,15 @@ class MathTest extends FlatSpec with Matchers { val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2) val x = sd.bind(arr) - println(x(0,0,SDIndex.all()).eval) + x.get(SDIndex.all(), SDIndex.all(), SDIndex.all()).eval shouldBe x(---, ---, ---).eval + x.get(SDIndex.point(0), SDIndex.all(), SDIndex.all()).eval shouldBe x(0, ---, ---).eval + x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval + x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval + + x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, + 0 :: 1, + 0 :: 2).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval } }