parent
b36b3fa1aa
commit
4956048a11
|
@ -54,9 +54,9 @@ class SameDiffWrapper {
|
|||
sd.placeHolder(name, dataType, shape: _*)
|
||||
}
|
||||
|
||||
case class SDIndexWrapper(start: Long) {
|
||||
case class SDIndexWrapper(end: Long) {
|
||||
|
||||
def ::(end: Long): SDIndex =
|
||||
def ::(start: Long): SDIndex =
|
||||
SDIndex.interval(start, end)
|
||||
}
|
||||
|
||||
|
@ -75,18 +75,12 @@ class SDVariableWrapper {
|
|||
thisVariable = variable
|
||||
}
|
||||
|
||||
// Indexing
|
||||
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
|
||||
|
||||
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
|
||||
|
||||
/*def apply(x: SDIndex, y: SDIndex): SDVariable =
|
||||
(x, y) match {
|
||||
case (_, y) => thisVariable.get(SDIndex.all(), y)
|
||||
case (x, _) => thisVariable.get(x, SDIndex.all())
|
||||
case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all())
|
||||
case (x, y) => thisVariable.get(x, y)
|
||||
}*/
|
||||
|
||||
// Arithmetic
|
||||
def add(other: Double): Unit = thisVariable.add(other)
|
||||
|
||||
def *(other: SDVariable): SDVariable =
|
||||
|
|
|
@ -48,4 +48,7 @@ object Implicits {
|
|||
val result = new SDIndexWrapper(start)
|
||||
result
|
||||
}
|
||||
|
||||
implicit def LongToPoint(x: Long): SDIndex =
|
||||
SDIndex.point(x)
|
||||
}
|
||||
|
|
|
@ -219,16 +219,19 @@ class MathTest extends FlatSpec with Matchers {
|
|||
|
||||
val x = sd.bind(arr)
|
||||
|
||||
println(x(SDIndex.point(0), ---).getArr)
|
||||
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
||||
|
||||
val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr
|
||||
println(data1)
|
||||
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
||||
val slice2 = x(0 :: 2, ---).eval
|
||||
slice1 shouldBe slice2
|
||||
}
|
||||
|
||||
val data2 = x(0 :: 2, ---).getArr
|
||||
println(data2)
|
||||
"SDVariable " should "be indexable in 3d" in {
|
||||
implicit val sd = SameDiff.create
|
||||
|
||||
//assert(indices.indices == List(0, 1, 2, 3, 4, 5))
|
||||
//assert(indices.targetShape.toList == List(2, 3))
|
||||
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2)
|
||||
val x = sd.bind(arr)
|
||||
|
||||
println(x(0,0,SDIndex.all()).eval)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue