commit
f515151f5f
|
@ -44,6 +44,9 @@ class SameDiffWrapper {
|
||||||
def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
|
def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
|
||||||
sd.`var`(name, dataType, shape: _*)
|
sd.`var`(name, dataType, shape: _*)
|
||||||
|
|
||||||
|
def bind(data: INDArray): SDVariable =
|
||||||
|
sd.`var`("", data)
|
||||||
|
|
||||||
def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
|
def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
|
||||||
sd.`var`(name, dataType, shape: _*)
|
sd.`var`(name, dataType, shape: _*)
|
||||||
|
|
||||||
|
@ -51,18 +54,33 @@ class SameDiffWrapper {
|
||||||
sd.placeHolder(name, dataType, shape: _*)
|
sd.placeHolder(name, dataType, shape: _*)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case class SDIndexWrapper(end: Long) {
|
||||||
|
|
||||||
|
def ::(start: Long): SDIndex =
|
||||||
|
SDIndex.interval(start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
object --- extends SDIndex {
|
||||||
|
val thisIndex: SDIndex = SDIndex.all()
|
||||||
|
}
|
||||||
|
|
||||||
class SDVariableWrapper {
|
class SDVariableWrapper {
|
||||||
|
|
||||||
var thisVariable: SDVariable = null
|
var thisVariable: SDVariable = null
|
||||||
var isScalar: Boolean = false
|
var isScalar: Boolean = false
|
||||||
|
val --- : SDIndex = SDIndex.all()
|
||||||
|
|
||||||
def this(variable: SDVariable) {
|
def this(variable: SDVariable) {
|
||||||
this
|
this
|
||||||
thisVariable = variable
|
thisVariable = variable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Indexing
|
||||||
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
|
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
|
||||||
|
|
||||||
|
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
|
||||||
|
|
||||||
|
// Arithmetic
|
||||||
def add(other: Double): Unit = thisVariable.add(other)
|
def add(other: Double): Unit = thisVariable.add(other)
|
||||||
|
|
||||||
def *(other: SDVariable): SDVariable =
|
def *(other: SDVariable): SDVariable =
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.nd4s.samediff.implicits
|
package org.nd4s.samediff.implicits
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
|
import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper }
|
import org.nd4s.samediff.{ SDIndexWrapper, SDVariableWrapper, SameDiffWrapper }
|
||||||
|
|
||||||
object Implicits {
|
object Implicits {
|
||||||
implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
|
implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
|
||||||
|
@ -43,4 +43,12 @@ object Implicits {
|
||||||
result.isScalar = true
|
result.isScalar = true
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
implicit def RangeToWrapper(start: Long): SDIndexWrapper = {
|
||||||
|
val result = new SDIndexWrapper(start)
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit def LongToPoint(x: Long): SDIndex =
|
||||||
|
SDIndex.point(x)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.nd4j.linalg.api.buffer.DataType
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import org.nd4s.Implicits._
|
import org.nd4s.Implicits._
|
||||||
|
import org.nd4s.NDOrdering
|
||||||
import org.nd4s.samediff.implicits.Implicits._
|
import org.nd4s.samediff.implicits.Implicits._
|
||||||
import org.scalatest.{ FlatSpec, Matchers }
|
import org.scalatest.{ FlatSpec, Matchers }
|
||||||
|
|
||||||
|
@ -205,9 +206,41 @@ class MathTest extends FlatSpec with Matchers {
|
||||||
implicit val sd = SameDiff.create
|
implicit val sd = SameDiff.create
|
||||||
|
|
||||||
val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
|
val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
|
||||||
val x = sd.`var`(arr)
|
val x = sd.bind(arr)
|
||||||
val y = new SDVariableWrapper(x)
|
val y = new SDVariableWrapper(x)
|
||||||
|
|
||||||
x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
|
x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"SDVariable " should "be indexable in 2d" in {
|
||||||
|
implicit val sd = SameDiff.create
|
||||||
|
|
||||||
|
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3, 3)
|
||||||
|
|
||||||
|
val x = sd.bind(arr)
|
||||||
|
|
||||||
|
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
||||||
|
|
||||||
|
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
||||||
|
val slice2 = x(0 :: 2, ---).eval
|
||||||
|
slice1 shouldBe slice2
|
||||||
|
}
|
||||||
|
|
||||||
|
"SDVariable " should "be indexable in 3d" in {
|
||||||
|
implicit val sd = SameDiff.create
|
||||||
|
|
||||||
|
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2)
|
||||||
|
val x = sd.bind(arr)
|
||||||
|
|
||||||
|
x.get(SDIndex.all(), SDIndex.all(), SDIndex.all()).eval shouldBe x(---, ---, ---).eval
|
||||||
|
x.get(SDIndex.point(0), SDIndex.all(), SDIndex.all()).eval shouldBe x(0, ---, ---).eval
|
||||||
|
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
|
||||||
|
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
|
||||||
|
|
||||||
|
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
|
||||||
|
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
|
||||||
|
0 :: 1,
|
||||||
|
0 :: 2).eval
|
||||||
|
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue