Some sugar

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-10-08 16:41:12 +03:00
parent 7a1e327516
commit b36b3fa1aa
3 changed files with 34 additions and 6 deletions

View File

@ -54,10 +54,21 @@ class SameDiffWrapper {
sd.placeHolder(name, dataType, shape: _*) sd.placeHolder(name, dataType, shape: _*)
} }
case class SDIndexWrapper(start: Long) {
def ::(end: Long): SDIndex =
SDIndex.interval(start, end)
}
object --- extends SDIndex {
val thisIndex: SDIndex = SDIndex.all()
}
class SDVariableWrapper { class SDVariableWrapper {
var thisVariable: SDVariable = null var thisVariable: SDVariable = null
var isScalar: Boolean = false var isScalar: Boolean = false
val --- : SDIndex = SDIndex.all()
def this(variable: SDVariable) { def this(variable: SDVariable) {
this this
@ -68,13 +79,13 @@ class SDVariableWrapper {
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
def apply(x: SDIndex)(y: SDIndex): SDVariable = /*def apply(x: SDIndex, y: SDIndex): SDVariable =
(x, y) match { (x, y) match {
case (_, y) => thisVariable.get(SDIndex.all(), y) case (_, y) => thisVariable.get(SDIndex.all(), y)
case (x, _) => thisVariable.get(x, SDIndex.all()) case (x, _) => thisVariable.get(x, SDIndex.all())
case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all())
case (x, y) => thisVariable.get(x, y) case (x, y) => thisVariable.get(x, y)
} }*/
def add(other: Double): Unit = thisVariable.add(other) def add(other: Double): Unit = thisVariable.add(other)

View File

@ -15,9 +15,9 @@
******************************************************************************/ ******************************************************************************/
package org.nd4s.samediff.implicits package org.nd4s.samediff.implicits
import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff } import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper } import org.nd4s.samediff.{ SDIndexWrapper, SDVariableWrapper, SameDiffWrapper }
object Implicits { object Implicits {
implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper = implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
@ -43,4 +43,9 @@ object Implicits {
result.isScalar = true result.isScalar = true
result result
} }
implicit def RangeToWrapper(start: Long): SDIndexWrapper = {
val result = new SDIndexWrapper(start)
result
}
} }

View File

@ -20,6 +20,7 @@ import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._ import org.nd4s.Implicits._
import org.nd4s.NDOrdering
import org.nd4s.samediff.implicits.Implicits._ import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers } import org.scalatest.{ FlatSpec, Matchers }
@ -214,9 +215,20 @@ class MathTest extends FlatSpec with Matchers {
"SDVariable " should "be indexable in 2d" in { "SDVariable " should "be indexable in 2d" in {
implicit val sd = SameDiff.create implicit val sd = SameDiff.create
val arr = Nd4j.rand(2, 5) val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3, 3)
val x = sd.bind(arr) val x = sd.bind(arr)
println(x(SDIndex.point(0), _: SDIndex).getArr) println(x(SDIndex.point(0), ---).getArr)
val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr
println(data1)
val data2 = x(0 :: 2, ---).getArr
println(data2)
//assert(indices.indices == List(0, 1, 2, 3, 4, 5))
//assert(indices.targetShape.toList == List(2, 3))
} }
} }