parent
b8f2a83a5a
commit
7a1e327516
|
@ -44,6 +44,9 @@ class SameDiffWrapper {
|
|||
def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
|
||||
sd.`var`(name, dataType, shape: _*)
|
||||
|
||||
def bind(data: INDArray): SDVariable =
|
||||
sd.`var`("", data)
|
||||
|
||||
def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
|
||||
sd.`var`(name, dataType, shape: _*)
|
||||
|
||||
|
@ -63,6 +66,16 @@ class SDVariableWrapper {
|
|||
|
||||
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
|
||||
|
||||
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
|
||||
|
||||
def apply(x: SDIndex)(y: SDIndex): SDVariable =
|
||||
(x, y) match {
|
||||
case (_, y) => thisVariable.get(SDIndex.all(), y)
|
||||
case (x, _) => thisVariable.get(x, SDIndex.all())
|
||||
case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all())
|
||||
case (x, y) => thisVariable.get(x, y)
|
||||
}
|
||||
|
||||
def add(other: Double): Unit = thisVariable.add(other)
|
||||
|
||||
def *(other: SDVariable): SDVariable =
|
||||
|
|
|
@ -205,9 +205,18 @@ class MathTest extends FlatSpec with Matchers {
|
|||
implicit val sd = SameDiff.create
|
||||
|
||||
val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
|
||||
val x = sd.`var`(arr)
|
||||
val x = sd.bind(arr)
|
||||
val y = new SDVariableWrapper(x)
|
||||
|
||||
x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
|
||||
}
|
||||
|
||||
"SDVariable " should "be indexable in 2d" in {
|
||||
implicit val sd = SameDiff.create
|
||||
|
||||
val arr = Nd4j.rand(2, 5)
|
||||
val x = sd.bind(arr)
|
||||
|
||||
println(x(SDIndex.point(0), _: SDIndex).getArr)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue