Indexing demo

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-10-07 16:59:47 +03:00
parent b8f2a83a5a
commit 7a1e327516
2 changed files with 23 additions and 1 deletions

View File

@ -44,6 +44,9 @@ class SameDiffWrapper {
def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
sd.`var`(name, dataType, shape: _*)
def bind(data: INDArray): SDVariable =
sd.`var`("", data)
def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
sd.`var`(name, dataType, shape: _*)
@ -63,6 +66,16 @@ class SDVariableWrapper {
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
def apply(x: SDIndex)(y: SDIndex): SDVariable =
(x, y) match {
case (_, y) => thisVariable.get(SDIndex.all(), y)
case (x, _) => thisVariable.get(x, SDIndex.all())
case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all())
case (x, y) => thisVariable.get(x, y)
}
def add(other: Double): Unit = thisVariable.add(other)
def *(other: SDVariable): SDVariable =

View File

@ -205,9 +205,18 @@ class MathTest extends FlatSpec with Matchers {
implicit val sd = SameDiff.create
val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
val x = sd.`var`(arr)
val x = sd.bind(arr)
val y = new SDVariableWrapper(x)
x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
}
"SDVariable " should "be indexable in 2d" in {
implicit val sd = SameDiff.create
val arr = Nd4j.rand(2, 5)
val x = sd.bind(arr)
println(x(SDIndex.point(0), _: SDIndex).getArr)
}
}