parent
7a1e327516
commit
b36b3fa1aa
|
@ -54,10 +54,21 @@ class SameDiffWrapper {
|
|||
sd.placeHolder(name, dataType, shape: _*)
|
||||
}
|
||||
|
||||
case class SDIndexWrapper(start: Long) {
|
||||
|
||||
def ::(end: Long): SDIndex =
|
||||
SDIndex.interval(start, end)
|
||||
}
|
||||
|
||||
object --- extends SDIndex {
|
||||
val thisIndex: SDIndex = SDIndex.all()
|
||||
}
|
||||
|
||||
class SDVariableWrapper {
|
||||
|
||||
var thisVariable: SDVariable = null
|
||||
var isScalar: Boolean = false
|
||||
val --- : SDIndex = SDIndex.all()
|
||||
|
||||
def this(variable: SDVariable) {
|
||||
this
|
||||
|
@ -68,13 +79,13 @@ class SDVariableWrapper {
|
|||
|
||||
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
|
||||
|
||||
def apply(x: SDIndex)(y: SDIndex): SDVariable =
|
||||
/*def apply(x: SDIndex, y: SDIndex): SDVariable =
|
||||
(x, y) match {
|
||||
case (_, y) => thisVariable.get(SDIndex.all(), y)
|
||||
case (x, _) => thisVariable.get(x, SDIndex.all())
|
||||
case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all())
|
||||
case (x, y) => thisVariable.get(x, y)
|
||||
}
|
||||
}*/
|
||||
|
||||
def add(other: Double): Unit = thisVariable.add(other)
|
||||
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
******************************************************************************/
|
||||
package org.nd4s.samediff.implicits
|
||||
|
||||
import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
|
||||
import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper }
|
||||
import org.nd4s.samediff.{ SDIndexWrapper, SDVariableWrapper, SameDiffWrapper }
|
||||
|
||||
object Implicits {
|
||||
implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
|
||||
|
@ -43,4 +43,9 @@ object Implicits {
|
|||
result.isScalar = true
|
||||
result
|
||||
}
|
||||
|
||||
implicit def RangeToWrapper(start: Long): SDIndexWrapper = {
|
||||
val result = new SDIndexWrapper(start)
|
||||
result
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.nd4j.linalg.api.buffer.DataType
|
|||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4s.Implicits._
|
||||
import org.nd4s.NDOrdering
|
||||
import org.nd4s.samediff.implicits.Implicits._
|
||||
import org.scalatest.{ FlatSpec, Matchers }
|
||||
|
||||
|
@ -214,9 +215,20 @@ class MathTest extends FlatSpec with Matchers {
|
|||
"SDVariable " should "be indexable in 2d" in {
|
||||
implicit val sd = SameDiff.create
|
||||
|
||||
val arr = Nd4j.rand(2, 5)
|
||||
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3, 3)
|
||||
|
||||
val x = sd.bind(arr)
|
||||
|
||||
println(x(SDIndex.point(0), _: SDIndex).getArr)
|
||||
println(x(SDIndex.point(0), ---).getArr)
|
||||
|
||||
val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr
|
||||
println(data1)
|
||||
|
||||
val data2 = x(0 :: 2, ---).getArr
|
||||
println(data2)
|
||||
|
||||
//assert(indices.indices == List(0, 1, 2, 3, 4, 5))
|
||||
//assert(indices.targetShape.toList == List(2, 3))
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue