diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index fc70b6d3f..6a8924365 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -44,6 +44,9 @@ class SameDiffWrapper { def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable = sd.`var`(name, dataType, shape: _*) + def bind(data: INDArray): SDVariable = + sd.`var`("", data) + def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable = sd.`var`(name, dataType, shape: _*) @@ -63,6 +66,16 @@ class SDVariableWrapper { def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) + def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) + + def apply(x: SDIndex)(y: SDIndex): SDVariable = + (x, y) match { + case (_, y) => thisVariable.get(SDIndex.all(), y) + case (x, _) => thisVariable.get(x, SDIndex.all()) + case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) + case (x, y) => thisVariable.get(x, y) + } + def add(other: Double): Unit = thisVariable.add(other) def *(other: SDVariable): SDVariable = diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index 104356f24..50a77fc14 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -205,9 +205,18 @@ class MathTest extends FlatSpec with Matchers { implicit val sd = SameDiff.create val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L) - val x = sd.`var`(arr) + val x = sd.bind(arr) val y = new SDVariableWrapper(x) x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr } + + "SDVariable " should "be indexable in 2d" in { + implicit val sd = SameDiff.create + + val arr = Nd4j.rand(2, 5) + val x = sd.bind(arr) + + println(x(SDIndex.point(0), _: SDIndex).getArr) + } }