[WIP] nd4s - data types (#51)
* Fixed tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added conversions for Long Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added conversions for Long Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added data types Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Failing test Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Types in conversions Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added mixins for integer types Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Conversion of different types to scalar Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests for arrays construction Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Construction tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Fixed slicing Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Add own Executioner implementation to nd4s Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Filter operation activated * Collection tests activated Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Types in operations Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Commented unused code Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Types in operations Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * String implicit conversion added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * String implicit conversion added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
c969b724bb
commit
68b82f3856
|
@ -19,7 +19,7 @@ import org.nd4s.Implicits._
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.api.ops.Op
|
import org.nd4j.linalg.api.ops.Op
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import org.nd4s.ops.{ BitFilterOps, FilterOps, MapOps }
|
import org.nd4s.ops.{ BitFilterOps, FilterOps, FunctionalOpExecutioner, MapOps }
|
||||||
|
|
||||||
import scalaxy.loops._
|
import scalaxy.loops._
|
||||||
import scala.language.postfixOps
|
import scala.language.postfixOps
|
||||||
|
@ -33,7 +33,7 @@ trait CollectionLikeNDArray[A <: INDArray] {
|
||||||
|
|
||||||
def filter(f: Double => Boolean)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
|
def filter(f: Double => Boolean)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
|
||||||
val shape = underlying.shape()
|
val shape = underlying.shape()
|
||||||
ev.reshape(Nd4j.getExecutioner
|
ev.reshape(FunctionalOpExecutioner.apply
|
||||||
.exec(FilterOps(ev.linearView(underlying), f): Op)
|
.exec(FilterOps(ev.linearView(underlying), f): Op)
|
||||||
.asInstanceOf[A],
|
.asInstanceOf[A],
|
||||||
shape.map(_.toInt): _*)
|
shape.map(_.toInt): _*)
|
||||||
|
@ -41,7 +41,7 @@ trait CollectionLikeNDArray[A <: INDArray] {
|
||||||
|
|
||||||
def filterBit(f: Double => Boolean)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
|
def filterBit(f: Double => Boolean)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
|
||||||
val shape = underlying.shape()
|
val shape = underlying.shape()
|
||||||
ev.reshape(Nd4j.getExecutioner
|
ev.reshape(FunctionalOpExecutioner.apply
|
||||||
.exec(BitFilterOps(ev.linearView(underlying), f): Op)
|
.exec(BitFilterOps(ev.linearView(underlying), f): Op)
|
||||||
.asInstanceOf[A],
|
.asInstanceOf[A],
|
||||||
shape.map(_.toInt): _*)
|
shape.map(_.toInt): _*)
|
||||||
|
@ -49,7 +49,7 @@ trait CollectionLikeNDArray[A <: INDArray] {
|
||||||
|
|
||||||
def map(f: Double => Double)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
|
def map(f: Double => Double)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
|
||||||
val shape = underlying.shape()
|
val shape = underlying.shape()
|
||||||
ev.reshape(Nd4j.getExecutioner
|
ev.reshape(FunctionalOpExecutioner.apply
|
||||||
.exec(MapOps(ev.linearView(underlying), f): Op)
|
.exec(MapOps(ev.linearView(underlying), f): Op)
|
||||||
.asInstanceOf[A],
|
.asInstanceOf[A],
|
||||||
shape.map(_.toInt): _*)
|
shape.map(_.toInt): _*)
|
||||||
|
|
|
@ -12,12 +12,14 @@
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.nd4s
|
package org.nd4s
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import org.nd4j.linalg.indexing.{ INDArrayIndex, NDArrayIndex }
|
import org.nd4j.linalg.indexing.{ INDArrayIndex, NDArrayIndex }
|
||||||
|
|
||||||
import scala.collection.breakOut
|
import scala.collection.breakOut
|
||||||
object Implicits {
|
object Implicits {
|
||||||
|
|
||||||
|
@ -34,7 +36,7 @@ object Implicits {
|
||||||
implicit def sliceProjection2NDArray(sliced: SliceProjectedNDArray): INDArray = sliced.array
|
implicit def sliceProjection2NDArray(sliced: SliceProjectedNDArray): INDArray = sliced.array
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Avoid using Numeric[T].toDouble(t:T) for sequence transformation in XXColl2INDArray to minimize memory consumption.
|
Avoid using Numeric[T].toDouble(t:T) for sequence transformation in XXColl2INDArray to minimize memory consumption.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
implicit def floatArray2INDArray(s: Array[Float]): FloatArray2INDArray =
|
implicit def floatArray2INDArray(s: Array[Float]): FloatArray2INDArray =
|
||||||
|
@ -48,7 +50,7 @@ object Implicits {
|
||||||
Nd4j.create(underlying, shape, ord.value, offset)
|
Nd4j.create(underlying, shape, ord.value, offset)
|
||||||
|
|
||||||
def asNDArray(shape: Int*): INDArray =
|
def asNDArray(shape: Int*): INDArray =
|
||||||
Nd4j.create(underlying, shape.toArray: _*)
|
Nd4j.create(underlying.toArray, shape.toArray: _*)
|
||||||
|
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying)
|
def toNDArray: INDArray = Nd4j.create(underlying)
|
||||||
}
|
}
|
||||||
|
@ -64,7 +66,7 @@ object Implicits {
|
||||||
Nd4j.create(underlying, shape, offset, ord.value)
|
Nd4j.create(underlying, shape, offset, ord.value)
|
||||||
|
|
||||||
def asNDArray(shape: Int*): INDArray =
|
def asNDArray(shape: Int*): INDArray =
|
||||||
Nd4j.create(underlying, shape.toArray: _*)
|
Nd4j.create(underlying.toArray, shape.toArray: _*)
|
||||||
|
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying)
|
def toNDArray: INDArray = Nd4j.create(underlying)
|
||||||
}
|
}
|
||||||
|
@ -76,54 +78,168 @@ object Implicits {
|
||||||
implicit def jintColl2INDArray(s: Seq[java.lang.Integer]): IntArray2INDArray =
|
implicit def jintColl2INDArray(s: Seq[java.lang.Integer]): IntArray2INDArray =
|
||||||
new IntArray2INDArray(s.map(x => x: Int)(breakOut))
|
new IntArray2INDArray(s.map(x => x: Int)(breakOut))
|
||||||
class IntArray2INDArray(val underlying: Array[Int]) extends AnyVal {
|
class IntArray2INDArray(val underlying: Array[Int]) extends AnyVal {
|
||||||
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray =
|
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
|
||||||
Nd4j.create(underlying.map(_.toFloat), shape, ord.value, offset)
|
val strides = Nd4j.getStrides(shape, ord.value)
|
||||||
|
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.INT)
|
||||||
|
}
|
||||||
|
|
||||||
def asNDArray(shape: Int*): INDArray =
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
|
||||||
Nd4j.create(underlying.map(_.toFloat), shape.toArray: _*)
|
|
||||||
|
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toFloat).toArray)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class FloatMtrix2INDArray(val underlying: Seq[Seq[Float]]) extends AnyVal {
|
implicit def longColl2INDArray(s: Seq[Long]): LongArray2INDArray =
|
||||||
|
new LongArray2INDArray(s.toArray)
|
||||||
|
implicit def longArray2INDArray(s: Array[Long]): LongArray2INDArray =
|
||||||
|
new LongArray2INDArray(s)
|
||||||
|
implicit def jlongColl2INDArray(s: Seq[java.lang.Long]): LongArray2INDArray =
|
||||||
|
new LongArray2INDArray(s.map(x => x: Long)(breakOut))
|
||||||
|
class LongArray2INDArray(val underlying: Array[Long]) extends AnyVal {
|
||||||
|
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
|
||||||
|
val strides = Nd4j.getStrides(shape, ord.value)
|
||||||
|
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.LONG)
|
||||||
|
}
|
||||||
|
|
||||||
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit def shortColl2INDArray(s: Seq[Short]): ShortArray2INDArray =
|
||||||
|
new ShortArray2INDArray(s.toArray)
|
||||||
|
implicit def shortArray2INDArray(s: Array[Short]): ShortArray2INDArray =
|
||||||
|
new ShortArray2INDArray(s)
|
||||||
|
implicit def jshortColl2INDArray(s: Seq[java.lang.Short]): ShortArray2INDArray =
|
||||||
|
new ShortArray2INDArray(s.map(x => x: Short)(breakOut))
|
||||||
|
class ShortArray2INDArray(val underlying: Array[Short]) extends AnyVal {
|
||||||
|
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
|
||||||
|
val strides = Nd4j.getStrides(shape, ord.value)
|
||||||
|
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.SHORT)
|
||||||
|
}
|
||||||
|
|
||||||
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit def byteColl2INDArray(s: Seq[Byte]): ByteArray2INDArray =
|
||||||
|
new ByteArray2INDArray(s.toArray)
|
||||||
|
implicit def byteArray2INDArray(s: Array[Byte]): ByteArray2INDArray =
|
||||||
|
new ByteArray2INDArray(s)
|
||||||
|
implicit def jbyteColl2INDArray(s: Seq[java.lang.Byte]): ByteArray2INDArray =
|
||||||
|
new ByteArray2INDArray(s.map(x => x: Byte)(breakOut))
|
||||||
|
class ByteArray2INDArray(val underlying: Array[Byte]) extends AnyVal {
|
||||||
|
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
|
||||||
|
val strides = Nd4j.getStrides(shape, ord.value)
|
||||||
|
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.BYTE)
|
||||||
|
}
|
||||||
|
|
||||||
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit def booleanColl2INDArray(s: Seq[Boolean]): BooleanArray2INDArray =
|
||||||
|
new BooleanArray2INDArray(s.toArray)
|
||||||
|
implicit def booleanArray2INDArray(s: Array[Boolean]): BooleanArray2INDArray =
|
||||||
|
new BooleanArray2INDArray(s)
|
||||||
|
implicit def jbooleanColl2INDArray(s: Seq[java.lang.Boolean]): BooleanArray2INDArray =
|
||||||
|
new BooleanArray2INDArray(s.map(x => x: Boolean)(breakOut))
|
||||||
|
class BooleanArray2INDArray(val underlying: Array[Boolean]) extends AnyVal {
|
||||||
|
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
|
||||||
|
val strides = Nd4j.getStrides(shape, ord.value)
|
||||||
|
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.BOOL)
|
||||||
|
}
|
||||||
|
|
||||||
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit def stringArray2INDArray(s: Array[String]): StringArray2INDArray =
|
||||||
|
new StringArray2INDArray(s)
|
||||||
|
implicit def stringArray2CollArray(s: Seq[String]): StringArray2INDArray =
|
||||||
|
new StringArray2INDArray(s.toArray)
|
||||||
|
implicit def jstringColl2INDArray(s: Seq[java.lang.String]): StringArray2INDArray =
|
||||||
|
new StringArray2INDArray(s.map(x => x: String)(breakOut))
|
||||||
|
class StringArray2INDArray(val underlying: Array[String]) extends AnyVal {
|
||||||
|
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = ???
|
||||||
|
|
||||||
|
def asNDArray(shape: Int*): INDArray = ???
|
||||||
|
|
||||||
|
def toNDArray: INDArray = Nd4j.create(underlying: _*)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class FloatMatrix2INDArray(val underlying: Seq[Seq[Float]]) extends AnyVal {
|
||||||
def mkNDArray(ord: NDOrdering): INDArray =
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
Nd4j.create(underlying.map(_.toArray).toArray, ord.value)
|
Nd4j.create(underlying.map(_.toArray).toArray, ord.value)
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toArray).toArray)
|
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toArray).toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class FloatArrayMtrix2INDArray(val underlying: Array[Array[Float]]) extends AnyVal {
|
implicit class FloatArrayMatrix2INDArray(val underlying: Array[Array[Float]]) extends AnyVal {
|
||||||
def mkNDArray(ord: NDOrdering): INDArray =
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
Nd4j.create(underlying, ord.value)
|
Nd4j.create(underlying, ord.value)
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying)
|
def toNDArray: INDArray = Nd4j.create(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class DoubleMtrix2INDArray(val underlying: Seq[Seq[Double]]) extends AnyVal {
|
implicit class DoubleMatrix2INDArray(val underlying: Seq[Seq[Double]]) extends AnyVal {
|
||||||
def mkNDArray(ord: NDOrdering): INDArray =
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
Nd4j.create(underlying.map(_.toArray).toArray, ord.value)
|
Nd4j.create(underlying.map(_.toArray).toArray, ord.value)
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toArray).toArray)
|
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toArray).toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class DoubleArrayMtrix2INDArray(val underlying: Array[Array[Double]]) extends AnyVal {
|
implicit class DoubleArrayMatrix2INDArray(val underlying: Array[Array[Double]]) extends AnyVal {
|
||||||
def mkNDArray(ord: NDOrdering): INDArray =
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
Nd4j.create(underlying, ord.value)
|
Nd4j.create(underlying, ord.value)
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying)
|
def toNDArray: INDArray = Nd4j.create(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class IntMtrix2INDArray(val underlying: Seq[Seq[Int]]) extends AnyVal {
|
implicit class IntMatrix2INDArray(val underlying: Seq[Seq[Int]]) extends AnyVal {
|
||||||
def mkNDArray(ord: NDOrdering): INDArray =
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
Nd4j.create(underlying.map(_.map(_.toFloat).toArray).toArray, ord.value)
|
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
def toNDArray: INDArray =
|
def toNDArray: INDArray =
|
||||||
Nd4j.create(underlying.map(_.map(_.toFloat).toArray).toArray)
|
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class IntArrayMtrix2INDArray(val underlying: Array[Array[Int]]) extends AnyVal {
|
implicit class IntArrayMatrix2INDArray(val underlying: Array[Array[Int]]) extends AnyVal {
|
||||||
def mkNDArray(ord: NDOrdering): INDArray =
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
Nd4j.create(underlying.map(_.map(_.toFloat)), ord.value)
|
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
def toNDArray: INDArray = Nd4j.create(underlying.map(_.map(_.toFloat)))
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class Num2Scalar[T](val underlying: T)(implicit ev: Numeric[T]) {
|
implicit class LongMatrix2INDArray(val underlying: Seq[Seq[Long]]) extends AnyVal {
|
||||||
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
|
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
|
def toNDArray: INDArray =
|
||||||
|
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class LongArrayMatrix2INDArray(val underlying: Array[Array[Long]]) extends AnyVal {
|
||||||
|
def mkNDArray(ord: NDOrdering): INDArray =
|
||||||
|
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
|
def toNDArray: INDArray = Nd4j.createFromArray(underlying.map(_.toArray).toArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*implicit class Num2Scalar[T](val underlying: T)(implicit ev: Numeric[T]) {
|
||||||
def toScalar: INDArray = Nd4j.scalar(ev.toDouble(underlying))
|
def toScalar: INDArray = Nd4j.scalar(ev.toDouble(underlying))
|
||||||
|
}*/
|
||||||
|
|
||||||
|
implicit class Float2Scalar(val underlying: Float) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class Double2Scalar(val underlying: Double) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class Long2Scalar(val underlying: Long) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class Int2Scalar(val underlying: Int) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class Byte2Scalar(val underlying: Byte) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class Boolean2Scalar(val underlying: Boolean) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class String2Scalar(val underlying: String) {
|
||||||
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit def intArray2IndexRangeArray(arr: Array[Int]): Array[IndexRange] =
|
implicit def intArray2IndexRangeArray(arr: Array[Int]): Array[IndexRange] =
|
||||||
|
|
|
@ -22,6 +22,9 @@ import org.nd4s.Implicits._
|
||||||
object Evidences {
|
object Evidences {
|
||||||
implicit val double = DoubleNDArrayEvidence
|
implicit val double = DoubleNDArrayEvidence
|
||||||
implicit val float = FloatNDArrayEvidence
|
implicit val float = FloatNDArrayEvidence
|
||||||
|
implicit val int = IntNDArrayEvidence
|
||||||
|
implicit val long = LongNDArrayEvidence
|
||||||
|
implicit val byte = ByteNDArrayEvidence
|
||||||
}
|
}
|
||||||
|
|
||||||
object NDArrayEvidence {
|
object NDArrayEvidence {
|
||||||
|
@ -342,3 +345,201 @@ case object FloatNDArrayEvidence extends RealNDArrayEvidence[Float] {
|
||||||
|
|
||||||
override def lessThan(left: Float, right: Float): Boolean = left < right
|
override def lessThan(left: Float, right: Float): Boolean = left < right
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait IntegerNDArrayEvidence[Value] extends NDArrayEvidence[INDArray, Value] {
|
||||||
|
def remainder(a: INDArray, that: INDArray): INDArray = a.remainder(that)
|
||||||
|
|
||||||
|
def add(a: INDArray, that: INDArray): INDArray = a.add(that)
|
||||||
|
|
||||||
|
def sub(a: INDArray, that: INDArray): INDArray = a.sub(that)
|
||||||
|
|
||||||
|
def mul(a: INDArray, that: INDArray): INDArray = a.mul(that)
|
||||||
|
|
||||||
|
def mmul(a: INDArray, that: INDArray): INDArray = a.mmul(that)
|
||||||
|
|
||||||
|
def div(a: INDArray, that: INDArray): INDArray = a.div(that)
|
||||||
|
|
||||||
|
def rdiv(a: INDArray, that: INDArray): INDArray = a.rdiv(that)
|
||||||
|
|
||||||
|
def addi(a: INDArray, that: INDArray): INDArray = a.addi(that)
|
||||||
|
|
||||||
|
def subi(a: INDArray, that: INDArray): INDArray = a.subi(that)
|
||||||
|
|
||||||
|
def muli(a: INDArray, that: INDArray): INDArray = a.muli(that)
|
||||||
|
|
||||||
|
def mmuli(a: INDArray, that: INDArray): INDArray = a.mmuli(that)
|
||||||
|
|
||||||
|
def remainderi(a: INDArray, that: INDArray): INDArray = a.remainder(that)
|
||||||
|
|
||||||
|
def remainderi(a: INDArray, that: Number): INDArray = a.remainderi(that)
|
||||||
|
|
||||||
|
def divi(a: INDArray, that: INDArray): INDArray = a.divi(that)
|
||||||
|
|
||||||
|
def rdivi(a: INDArray, that: INDArray): INDArray = a.rdivi(that)
|
||||||
|
|
||||||
|
def remainder(a: INDArray, that: Number): INDArray = a.remainder(that)
|
||||||
|
|
||||||
|
def add(a: INDArray, that: Number): INDArray = a.add(that)
|
||||||
|
|
||||||
|
def sub(a: INDArray, that: Number): INDArray = a.sub(that)
|
||||||
|
|
||||||
|
def mul(a: INDArray, that: Number): INDArray = a.mul(that)
|
||||||
|
|
||||||
|
def div(a: INDArray, that: Number): INDArray = a.div(that)
|
||||||
|
|
||||||
|
def rdiv(a: INDArray, that: Number): INDArray = a.rdiv(that)
|
||||||
|
|
||||||
|
def addi(a: INDArray, that: Number): INDArray = a.addi(that)
|
||||||
|
|
||||||
|
def subi(a: INDArray, that: Number): INDArray = a.subi(that)
|
||||||
|
|
||||||
|
def muli(a: INDArray, that: Number): INDArray = a.muli(that)
|
||||||
|
|
||||||
|
def divi(a: INDArray, that: Number): INDArray = a.divi(that)
|
||||||
|
|
||||||
|
def rdivi(a: INDArray, that: Number): INDArray = a.rdivi(that)
|
||||||
|
|
||||||
|
def put(a: INDArray, i: Int, element: INDArray): INDArray = a.put(i, element)
|
||||||
|
|
||||||
|
def put(a: INDArray, i: Array[Int], element: INDArray): INDArray = a.put(i, element)
|
||||||
|
|
||||||
|
def get(a: INDArray, i: INDArrayIndex*): INDArray = a.get(i: _*)
|
||||||
|
|
||||||
|
def reshape(a: INDArray, i: Int*): INDArray = a.reshape(i.map(_.toLong): _*)
|
||||||
|
|
||||||
|
def linearView(a: INDArray): INDArray = a.reshape(-1)
|
||||||
|
|
||||||
|
def dup(a: INDArray): INDArray = a.dup()
|
||||||
|
|
||||||
|
def update(a: INDArray, indices: Array[IndexRange], i: Int): INDArray =
|
||||||
|
a.update(indices, i)
|
||||||
|
|
||||||
|
def update(a: INDArray, indices: Array[IndexRange], i: INDArray): INDArray =
|
||||||
|
a.update(indices, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
case object IntNDArrayEvidence extends IntegerNDArrayEvidence[Int] {
|
||||||
|
|
||||||
|
def sum(ndarray: INDArray): Int = ndarray.sumNumber().intValue()
|
||||||
|
|
||||||
|
def mean(ndarray: INDArray): Int = ndarray.meanNumber().intValue()
|
||||||
|
|
||||||
|
def normMax(ndarray: INDArray): Int = ndarray.normmaxNumber().intValue()
|
||||||
|
|
||||||
|
def norm1(ndarray: INDArray): Int = ndarray.norm1Number().intValue()
|
||||||
|
|
||||||
|
def norm2(ndarray: INDArray): Int = ndarray.norm2Number().intValue()
|
||||||
|
|
||||||
|
def max(ndarray: INDArray): Int = ndarray.maxNumber().intValue()
|
||||||
|
|
||||||
|
def min(ndarray: INDArray): Int = ndarray.minNumber().intValue()
|
||||||
|
|
||||||
|
def standardDeviation(ndarray: INDArray): Int = ndarray.stdNumber().intValue()
|
||||||
|
|
||||||
|
def product(ndarray: INDArray): Int = ndarray.prodNumber().intValue()
|
||||||
|
|
||||||
|
def variance(ndarray: INDArray): Int = ndarray.varNumber().intValue()
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int): Int = a.getInt(i)
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int, j: Int): Int = a.getInt(i, j)
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int*): Int = a.getInt(i: _*)
|
||||||
|
|
||||||
|
def create(arr: Array[Int]): INDArray = arr.toNDArray
|
||||||
|
|
||||||
|
def create(arr: Array[Int], shape: Int*): INDArray = arr.toNDArray
|
||||||
|
|
||||||
|
def create(arr: Array[Int], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
|
||||||
|
arr.mkNDArray(shape, ordering, offset)
|
||||||
|
|
||||||
|
def greaterThan(left: Int, right: Int): Boolean = left > right
|
||||||
|
|
||||||
|
def lessThan(left: Int, right: Int): Boolean = left < right
|
||||||
|
}
|
||||||
|
|
||||||
|
case object LongNDArrayEvidence extends IntegerNDArrayEvidence[Long] {
|
||||||
|
|
||||||
|
def sum(ndarray: INDArray): Long = ndarray.sumNumber().longValue()
|
||||||
|
|
||||||
|
def mean(ndarray: INDArray): Long = ndarray.meanNumber().longValue()
|
||||||
|
|
||||||
|
def normMax(ndarray: INDArray): Long = ndarray.normmaxNumber().longValue()
|
||||||
|
|
||||||
|
def norm1(ndarray: INDArray): Long = ndarray.norm1Number().longValue()
|
||||||
|
|
||||||
|
def norm2(ndarray: INDArray): Long = ndarray.norm2Number().longValue()
|
||||||
|
|
||||||
|
def max(ndarray: INDArray): Long = ndarray.maxNumber().longValue()
|
||||||
|
|
||||||
|
def min(ndarray: INDArray): Long = ndarray.minNumber().longValue()
|
||||||
|
|
||||||
|
def standardDeviation(ndarray: INDArray): Long = ndarray.stdNumber().longValue()
|
||||||
|
|
||||||
|
def product(ndarray: INDArray): Long = ndarray.prodNumber().longValue()
|
||||||
|
|
||||||
|
def variance(ndarray: INDArray): Long = ndarray.varNumber().longValue()
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int): Long = a.getLong(i)
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int, j: Int): Long = a.getLong(i, j)
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int*): Long = a.getLong(i.map(_.toLong): _*)
|
||||||
|
|
||||||
|
def create(arr: Array[Long]): INDArray = arr.toNDArray
|
||||||
|
|
||||||
|
def create(arr: Array[Long], shape: Int*): INDArray = arr.toNDArray
|
||||||
|
|
||||||
|
def create(arr: Array[Long], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
|
||||||
|
arr.mkNDArray(shape, ordering, offset)
|
||||||
|
|
||||||
|
def greaterThan(left: Long, right: Long): Boolean = left > right
|
||||||
|
|
||||||
|
def lessThan(left: Long, right: Long): Boolean = left < right
|
||||||
|
|
||||||
|
def update(a: INDArray, indices: Array[IndexRange], i: Long): INDArray =
|
||||||
|
a.update(indices, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
case object ByteNDArrayEvidence extends IntegerNDArrayEvidence[Byte] {
|
||||||
|
|
||||||
|
def sum(ndarray: INDArray): Byte = ndarray.sumNumber().byteValue()
|
||||||
|
|
||||||
|
def mean(ndarray: INDArray): Byte = ndarray.meanNumber().byteValue()
|
||||||
|
|
||||||
|
def normMax(ndarray: INDArray): Byte = ndarray.normmaxNumber().byteValue()
|
||||||
|
|
||||||
|
def norm1(ndarray: INDArray): Byte = ndarray.norm1Number().byteValue()
|
||||||
|
|
||||||
|
def norm2(ndarray: INDArray): Byte = ndarray.norm2Number().byteValue()
|
||||||
|
|
||||||
|
def max(ndarray: INDArray): Byte = ndarray.maxNumber().byteValue()
|
||||||
|
|
||||||
|
def min(ndarray: INDArray): Byte = ndarray.minNumber().byteValue()
|
||||||
|
|
||||||
|
def standardDeviation(ndarray: INDArray): Byte = ndarray.stdNumber().byteValue()
|
||||||
|
|
||||||
|
def product(ndarray: INDArray): Byte = ndarray.prodNumber().byteValue()
|
||||||
|
|
||||||
|
def variance(ndarray: INDArray): Byte = ndarray.varNumber().byteValue()
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int): Byte = a.getInt(i).toByte
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int, j: Int): Byte = a.getInt(i, j).toByte
|
||||||
|
|
||||||
|
def get(a: INDArray, i: Int*): Byte = a.getInt(i.map(_.toInt): _*).toByte
|
||||||
|
|
||||||
|
def create(arr: Array[Byte]): INDArray = arr.toNDArray
|
||||||
|
|
||||||
|
def create(arr: Array[Byte], shape: Int*): INDArray = arr.toNDArray
|
||||||
|
|
||||||
|
def create(arr: Array[Byte], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
|
||||||
|
arr.mkNDArray(shape, ordering, offset)
|
||||||
|
|
||||||
|
def greaterThan(left: Byte, right: Byte): Boolean = left > right
|
||||||
|
|
||||||
|
def lessThan(left: Byte, right: Byte): Boolean = left < right
|
||||||
|
|
||||||
|
def update(a: INDArray, indices: Array[IndexRange], i: Byte): INDArray =
|
||||||
|
a.update(indices, i)
|
||||||
|
}
|
||||||
|
|
|
@ -48,6 +48,7 @@ trait SliceableNDArray[A <: INDArray] {
|
||||||
ev.create(filtered, targetShape, NDOrdering(underlying.ordering()), 0)
|
ev.create(filtered, targetShape, NDOrdering(underlying.ordering()), 0)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
ev.get(underlying, getINDArrayIndexfrom(target: _*): _*)
|
ev.get(underlying, getINDArrayIndexfrom(target: _*): _*)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -121,10 +122,10 @@ trait SliceableNDArray[A <: INDArray] {
|
||||||
underlying.shape().map(_.toInt).toList
|
underlying.shape().map(_.toInt).toList
|
||||||
|
|
||||||
val originalTarget =
|
val originalTarget =
|
||||||
if (underlying.isRowVector && target.size == 1)
|
/*if (underlying.isRowVector && target.size == 1)
|
||||||
IntRange(0) +: target
|
IntRange(0) +: target
|
||||||
else
|
else*/
|
||||||
target
|
target
|
||||||
|
|
||||||
@tailrec
|
@tailrec
|
||||||
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[INDArrayIndex]): List[INDArrayIndex] =
|
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[INDArrayIndex]): List[INDArrayIndex] =
|
||||||
|
|
|
@ -55,4 +55,11 @@ class BitFilterOps(_x: INDArray, len: Int, f: Double => Boolean)
|
||||||
|
|
||||||
override def op(origin: Float): Float = if (f(origin)) 1 else 0
|
override def op(origin: Float): Float = if (f(origin)) 1 else 0
|
||||||
|
|
||||||
|
override def op(origin: Short): Short = if (f(origin)) 1 else 0
|
||||||
|
|
||||||
|
override def op(origin: Int): Int = if (f(origin)) 1 else 0
|
||||||
|
|
||||||
|
override def op(origin: Long): Long = if (f(origin)) 1 else 0
|
||||||
|
|
||||||
|
override def op(origin: String): String = ???
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,4 +54,12 @@ class FilterOps(_x: INDArray, len: Int, f: Double => Boolean)
|
||||||
|
|
||||||
override def op(origin: Float): Float = if (f(origin)) origin else 0
|
override def op(origin: Float): Float = if (f(origin)) origin else 0
|
||||||
|
|
||||||
|
override def op(origin: Short): Short = if (f(origin)) origin else 0
|
||||||
|
|
||||||
|
override def op(origin: Int): Int = if (f(origin)) origin else 0
|
||||||
|
|
||||||
|
override def op(origin: Long): Long = if (f(origin)) origin else 0
|
||||||
|
|
||||||
|
override def op(origin: String): String = ???
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,474 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4s.ops
|
||||||
|
|
||||||
|
import java.util.{ List, Map, Properties }
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer
|
||||||
|
import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer }
|
||||||
|
import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics }
|
||||||
|
import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch }
|
||||||
|
import org.nd4j.linalg.api.ops._
|
||||||
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner
|
||||||
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate
|
||||||
|
import org.nd4j.linalg.api.ops.impl.summarystats.Variance
|
||||||
|
import org.nd4j.linalg.api.rng.Random
|
||||||
|
import org.nd4j.linalg.api.shape.{ LongShapeDescriptor, TadPack }
|
||||||
|
import org.nd4j.linalg.cache.TADManager
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import org.nd4j.linalg.profiler.ProfilerConfig
|
||||||
|
|
||||||
|
object FunctionalOpExecutioner {
|
||||||
|
def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner()
|
||||||
|
}
|
||||||
|
class FunctionalOpExecutioner extends OpExecutioner {
|
||||||
|
def isVerbose: Boolean = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns true if debug mode is enabled, false otherwise
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def isDebug: Boolean = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns type for this executioner instance
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def `type`: OpExecutioner.ExecutionerType = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns opName of the last invoked op
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def getLastOp: String = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute the operation
|
||||||
|
*
|
||||||
|
* @param op the operation to execute
|
||||||
|
*/
|
||||||
|
def exec(op: Op): INDArray =
|
||||||
|
op match {
|
||||||
|
case op: FilterOps => exec(op.asInstanceOf[FilterOps])
|
||||||
|
case op: BitFilterOps => exec(op.asInstanceOf[BitFilterOps])
|
||||||
|
case op: MapOps => exec(op.asInstanceOf[MapOps])
|
||||||
|
case _ => op.z()
|
||||||
|
}
|
||||||
|
|
||||||
|
def exec(op: FilterOps): INDArray = {
|
||||||
|
val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
|
||||||
|
for (i <- 0 until op.x().length().toInt) {
|
||||||
|
val filtered = op.x.dataType() match {
|
||||||
|
case DataType.DOUBLE => op.op(op.x.getDouble(i.toLong))
|
||||||
|
case DataType.FLOAT => op.op(op.x.getFloat(i.toLong))
|
||||||
|
case DataType.INT => op.op(op.x.getInt(i))
|
||||||
|
case DataType.SHORT => op.op(op.x.getInt(i))
|
||||||
|
case (DataType.LONG) => op.op(op.x.getLong(i.toLong))
|
||||||
|
}
|
||||||
|
retVal.putScalar(i, filtered)
|
||||||
|
}
|
||||||
|
retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
def exec(op: BitFilterOps): INDArray = {
|
||||||
|
val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
|
||||||
|
for (i <- 0 until op.x().length().toInt) {
|
||||||
|
val current = if (op.x.dataType() == DataType.DOUBLE) op.x().getDouble(i.toLong) else op.x().getInt(i)
|
||||||
|
val filtered = op.op(current)
|
||||||
|
|
||||||
|
retVal.putScalar(i, filtered)
|
||||||
|
}
|
||||||
|
retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
def exec(op: MapOps): INDArray = {
|
||||||
|
val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
|
||||||
|
for (i <- 0 until op.x().length().toInt) {
|
||||||
|
val current = if (op.x.dataType() == DataType.DOUBLE) op.x().getDouble(i.toLong) else op.x().getInt(i)
|
||||||
|
val filtered = op.op(current)
|
||||||
|
|
||||||
|
retVal.putScalar(i, filtered)
|
||||||
|
}
|
||||||
|
retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Execute a TransformOp and return the result
|
||||||
|
*
|
||||||
|
* @param op the operation to execute
|
||||||
|
*/
|
||||||
|
def execAndReturn(op: TransformOp): TransformOp = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute and return the result from an accumulation
|
||||||
|
*
|
||||||
|
* @param op the operation to execute
|
||||||
|
* @return the accumulated result
|
||||||
|
*/
|
||||||
|
def execAndReturn(op: ReduceOp): ReduceOp = ???
|
||||||
|
|
||||||
|
def execAndReturn(op: Variance): Variance = ???
|
||||||
|
|
||||||
|
/** Execute and return the result from an index accumulation
|
||||||
|
*
|
||||||
|
* @param op the index accumulation operation to execute
|
||||||
|
* @return the accumulated index
|
||||||
|
*/
|
||||||
|
def execAndReturn(op: IndexAccumulation): IndexAccumulation = ???
|
||||||
|
|
||||||
|
/** Execute and return the result from a scalar op
|
||||||
|
*
|
||||||
|
* @param op the operation to execute
|
||||||
|
* @return the accumulated result
|
||||||
|
*/
|
||||||
|
def execAndReturn(op: ScalarOp): ScalarOp = ???
|
||||||
|
|
||||||
|
/** Execute and return the result from a vector op
|
||||||
|
*
|
||||||
|
* @param op */
|
||||||
|
def execAndReturn(op: BroadcastOp): BroadcastOp = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute a reduceOp, possibly along one or more dimensions
|
||||||
|
*
|
||||||
|
* @param reduceOp the reduceOp
|
||||||
|
* @return the reduceOp op
|
||||||
|
*/
|
||||||
|
def exec(reduceOp: ReduceOp): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute a broadcast op, possibly along one or more dimensions
|
||||||
|
*
|
||||||
|
* @param broadcast the accumulation
|
||||||
|
* @return the broadcast op
|
||||||
|
*/
|
||||||
|
def exec(broadcast: BroadcastOp): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute ScalarOp
|
||||||
|
*
|
||||||
|
* @param broadcast
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def exec(broadcast: ScalarOp): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute an variance accumulation op, possibly along one or more dimensions
|
||||||
|
*
|
||||||
|
* @param accumulation the accumulation
|
||||||
|
* @return the accmulation op
|
||||||
|
*/
|
||||||
|
def exec(accumulation: Variance): INDArray = ???
|
||||||
|
|
||||||
|
/** Execute an index accumulation along one or more dimensions
|
||||||
|
*
|
||||||
|
* @param indexAccum the index accumulation operation
|
||||||
|
* @return result
|
||||||
|
*/
|
||||||
|
def exec(indexAccum: IndexAccumulation): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* Execute and return a result
|
||||||
|
* ndarray from the given op
|
||||||
|
*
|
||||||
|
* @param op the operation to execute
|
||||||
|
* @return the result from the operation
|
||||||
|
*/
|
||||||
|
def execAndReturn(op: Op): Op = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute MetaOp
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
*/
|
||||||
|
def exec(op: MetaOp): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute GridOp
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
*/
|
||||||
|
def exec(op: GridOp): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
*/
|
||||||
|
def exec(op: Aggregate): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method executes previously built batch
|
||||||
|
*
|
||||||
|
* @param batch
|
||||||
|
*/
|
||||||
|
def exec[T <: Aggregate](batch: Batch[T]): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method takes arbitrary sized list of aggregates,
|
||||||
|
* and packs them into batches
|
||||||
|
*
|
||||||
|
* @param batch
|
||||||
|
*/
|
||||||
|
def exec(batch: java.util.List[Aggregate]): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
*/
|
||||||
|
def exec(op: RandomOp): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method executes specific RandomOp against specified RNG
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
* @param rng
|
||||||
|
*/
|
||||||
|
def exec(op: RandomOp, rng: Random): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method return set of key/value and
|
||||||
|
* key/key/value objects,
|
||||||
|
* describing current environment
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def getEnvironmentInformation: Properties = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method specifies desired profiling mode
|
||||||
|
*
|
||||||
|
* @param mode
|
||||||
|
*/
|
||||||
|
@deprecated def setProfilingMode(mode: OpExecutioner.ProfilingMode): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method stores specified configuration.
|
||||||
|
*
|
||||||
|
* @param config
|
||||||
|
*/
|
||||||
|
def setProfilingConfig(config: ProfilerConfig): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ths method returns current profiling
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@deprecated def getProfilingMode: OpExecutioner.ProfilingMode = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns TADManager instance used for this OpExecutioner
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def getTADManager: TADManager = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method prints out environmental information returned by getEnvironmentInformation() method
|
||||||
|
*/
|
||||||
|
def printEnvironmentInformation(): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method ensures all operations that supposed to be executed at this moment, are executed.
|
||||||
|
*/
|
||||||
|
def push(): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method ensures all operations that supposed to be executed at this moment, are executed and finished.
|
||||||
|
*/
|
||||||
|
def commit(): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method encodes array as thresholds, updating input array at the same time
|
||||||
|
*
|
||||||
|
* @param input
|
||||||
|
* @return encoded array is returned
|
||||||
|
*/
|
||||||
|
def thresholdEncode(input: INDArray, threshold: Double): INDArray = ???
|
||||||
|
|
||||||
|
def thresholdEncode(input: INDArray, threshold: Double, boundary: Integer): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method decodes thresholds array, and puts it into target array
|
||||||
|
*
|
||||||
|
* @param encoded
|
||||||
|
* @param target
|
||||||
|
* @return target is returned
|
||||||
|
*/
|
||||||
|
def thresholdDecode(encoded: INDArray, target: INDArray): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns number of elements affected by encoder
|
||||||
|
*
|
||||||
|
* @param indArray
|
||||||
|
* @param target
|
||||||
|
* @param threshold
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def bitmapEncode(indArray: INDArray, target: INDArray, threshold: Double): Long = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param indArray
|
||||||
|
* @param threshold
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def bitmapEncode(indArray: INDArray, threshold: Double): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param encoded
|
||||||
|
* @param target
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def bitmapDecode(encoded: INDArray, target: INDArray): INDArray = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns names of all custom operations available in current backend, and their number of input/output arguments
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def getCustomOperations: java.util.Map[String, CustomOpDescriptor] = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method executes given CustomOp
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: You're responsible for input/output validation
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
*/
|
||||||
|
def execAndReturn(op: CustomOp): CustomOp = ???
|
||||||
|
|
||||||
|
def exec(op: CustomOp): Array[INDArray] = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method executes op with given context
|
||||||
|
*
|
||||||
|
* @param op
|
||||||
|
* @param context
|
||||||
|
* @return method returns output arrays defined within context
|
||||||
|
*/
|
||||||
|
def exec(op: CustomOp, context: OpContext): Array[INDArray] = ???
|
||||||
|
|
||||||
|
def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Equivalent to calli
|
||||||
|
*/
|
||||||
|
def allocateOutputArrays(op: CustomOp): Array[INDArray] = ???
|
||||||
|
|
||||||
|
def enableDebugMode(reallyEnable: Boolean): Unit = ???
|
||||||
|
|
||||||
|
def enableVerboseMode(reallyEnable: Boolean): Unit = ???
|
||||||
|
|
||||||
|
def isExperimentalMode: Boolean = ???
|
||||||
|
|
||||||
|
def registerGraph(id: Long, graph: Pointer): Unit = ???
|
||||||
|
|
||||||
|
def executeGraph(id: Long,
|
||||||
|
map: java.util.Map[String, INDArray],
|
||||||
|
reverseMap: java.util.Map[String, Integer]): java.util.Map[String, INDArray] = ???
|
||||||
|
|
||||||
|
def forgetGraph(id: Long): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows to set desired number of elements per thread, for performance optimization purposes.
|
||||||
|
* I.e. if array contains 2048 elements, and threshold is set to 1024, 2 threads will be used for given op execution.
|
||||||
|
*
|
||||||
|
* Default value: 1024
|
||||||
|
*
|
||||||
|
* @param threshold
|
||||||
|
*/
|
||||||
|
def setElementsThreshold(threshold: Int): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows to set desired number of sub-arrays per thread, for performance optimization purposes.
|
||||||
|
* I.e. if matrix has shape of 64 x 128, and threshold is set to 8, each thread will be processing 8 sub-arrays (sure, if you have 8 core cpu).
|
||||||
|
* If your cpu has, say, 4, cores, only 4 threads will be spawned, and each will process 16 sub-arrays
|
||||||
|
*
|
||||||
|
* Default value: 8
|
||||||
|
*
|
||||||
|
* @param threshold
|
||||||
|
*/
|
||||||
|
def setTadThreshold(threshold: Int): Unit = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method extracts String from Utf8Buffer
|
||||||
|
*
|
||||||
|
* @param buffer
|
||||||
|
* @param index
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def getString(buffer: Utf8Buffer, index: Long): String = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns OpContext which can be used (and reused) to execute custom ops
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def buildContext: OpContext = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param array
|
||||||
|
*/
|
||||||
|
def inspectArray(array: INDArray): INDArrayStatistics = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns shapeInfo DataBuffer
|
||||||
|
*
|
||||||
|
* @param shape
|
||||||
|
* @param stride
|
||||||
|
* @param elementWiseStride
|
||||||
|
* @param order
|
||||||
|
* @param dtype
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def createShapeInfo(shape: Array[Long],
|
||||||
|
stride: Array[Long],
|
||||||
|
elementWiseStride: Long,
|
||||||
|
order: Char,
|
||||||
|
dtype: DataType,
|
||||||
|
empty: Boolean): DataBuffer = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns host/device tad buffers
|
||||||
|
*/
|
||||||
|
def tadShapeInfoAndOffsets(array: INDArray, dimension: Array[Int]): TadPack = ???
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns constant buffer for the given jvm array
|
||||||
|
*
|
||||||
|
* @param values
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def createConstantBuffer(values: Array[Long], desiredType: DataType): DataBuffer = ???
|
||||||
|
|
||||||
|
def createConstantBuffer(values: Array[Int], desiredType: DataType): DataBuffer = ???
|
||||||
|
|
||||||
|
def createConstantBuffer(values: Array[Float], desiredType: DataType): DataBuffer = ???
|
||||||
|
|
||||||
|
def createConstantBuffer(values: Array[Double], desiredType: DataType): DataBuffer = ???
|
||||||
|
|
||||||
|
@deprecated def scatterUpdate(op: ScatterUpdate.UpdateOp,
|
||||||
|
array: INDArray,
|
||||||
|
indices: INDArray,
|
||||||
|
updates: INDArray,
|
||||||
|
axis: Array[Int]): Unit = ???
|
||||||
|
}
|
|
@ -34,5 +34,13 @@ trait LeftAssociativeBinaryOp {
|
||||||
|
|
||||||
def op(origin: Float): Float
|
def op(origin: Float): Float
|
||||||
|
|
||||||
|
def op(origin: Short): Short
|
||||||
|
|
||||||
|
def op(origin: Int): Int
|
||||||
|
|
||||||
|
def op(origin: Long): Long
|
||||||
|
|
||||||
|
def op(origin: String): String
|
||||||
|
|
||||||
// def op(origin: IComplexNumber): IComplexNumber
|
// def op(origin: IComplexNumber): IComplexNumber
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,4 +49,11 @@ class MapOps(_x: INDArray, f: Double => Double) extends BaseScalarOp(_x, null, _
|
||||||
|
|
||||||
override def op(origin: Float): Float = f(origin).toFloat
|
override def op(origin: Float): Float = f(origin).toFloat
|
||||||
|
|
||||||
|
override def op(origin: Short): Short = f(origin).toShort
|
||||||
|
|
||||||
|
override def op(origin: Int): Int = f(origin).toInt
|
||||||
|
|
||||||
|
override def op(origin: Long): Long = f(origin).toLong
|
||||||
|
|
||||||
|
override def op(origin: String): String = ???
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ class DSLSpec extends FlatSpec with Matchers {
|
||||||
|
|
||||||
// This test just verifies that an INDArray gets wrapped with an implicit conversion
|
// This test just verifies that an INDArray gets wrapped with an implicit conversion
|
||||||
|
|
||||||
val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1))
|
val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1): _*)
|
||||||
val nd1 = nd + 10L // + creates new array, += modifies in place
|
val nd1 = nd + 10L // + creates new array, += modifies in place
|
||||||
|
|
||||||
nd.get(0) should equal(1)
|
nd.get(0) should equal(1)
|
||||||
|
|
|
@ -19,7 +19,7 @@ import org.nd4s.Implicits._
|
||||||
import org.scalatest.{ FlatSpec, Matchers }
|
import org.scalatest.{ FlatSpec, Matchers }
|
||||||
|
|
||||||
class NDArrayCollectionAPITest extends FlatSpec with Matchers {
|
class NDArrayCollectionAPITest extends FlatSpec with Matchers {
|
||||||
"CollectionLikeNDArray" should "provides filter API" ignore {
|
"CollectionLikeNDArray" should "provides filter API" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1, 2, 3),
|
Array(1, 2, 3),
|
||||||
|
@ -38,7 +38,48 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers {
|
||||||
).toNDArray
|
).toNDArray
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
it should "provides filter bitmask API" ignore {
|
|
||||||
|
"CollectionLikeNDArray from Floats" should "provides filter API" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1f, 2f, 3f),
|
||||||
|
Array(4f, 5f, 6f),
|
||||||
|
Array(7f, 8f, 9f)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val filtered = ndArray.filter(_ > 3)
|
||||||
|
|
||||||
|
assert(
|
||||||
|
filtered ==
|
||||||
|
Array(
|
||||||
|
Array(0f, 0f, 0f),
|
||||||
|
Array(4f, 5f, 6f),
|
||||||
|
Array(7f, 8f, 9f)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"CollectionLikeNDArray from Long " should "provides filter API" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1L, 2L, 3L),
|
||||||
|
Array(4L, 5L, 6L),
|
||||||
|
Array(7L, 8L, 9L)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val filtered = ndArray.filter(_ > 3)
|
||||||
|
|
||||||
|
assert(
|
||||||
|
filtered ==
|
||||||
|
Array(
|
||||||
|
Array(0L, 0L, 0L),
|
||||||
|
Array(4L, 5L, 6L),
|
||||||
|
Array(7L, 8L, 9L)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "provides filter bitmask API" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1, 2, 3),
|
Array(1, 2, 3),
|
||||||
|
@ -57,7 +98,7 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers {
|
||||||
).toNDArray
|
).toNDArray
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
it should "provides map API" ignore {
|
it should "provides map API" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1, 2, 3),
|
Array(1, 2, 3),
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4s
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType
|
||||||
|
import org.nd4s.Implicits._
|
||||||
|
import org.scalatest.FlatSpec
|
||||||
|
|
||||||
|
class NDArrayConstructionTest extends FlatSpec with COrderingForTest {
|
||||||
|
self: OrderingForTest =>
|
||||||
|
|
||||||
|
it should "be able to create 2d matrix filled with integers" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1, 2),
|
||||||
|
Array(4, 5),
|
||||||
|
Array(7, 9)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
|
||||||
|
assert(DataType.INT == ndArray.dataType())
|
||||||
|
assert(3 == ndArray.rows())
|
||||||
|
assert(2 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create 2d matrix filled with long integers" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1L, 2L, 3L),
|
||||||
|
Array(4L, 5L, 6L),
|
||||||
|
Array(7L, 8L, 9L)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
|
||||||
|
assert(DataType.LONG == ndArray.dataType())
|
||||||
|
assert(3 == ndArray.rows())
|
||||||
|
assert(3 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create 2d matrix filled with float numbers" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1f, 2f, 3f),
|
||||||
|
Array(4f, 5f, 6f),
|
||||||
|
Array(7f, 8f, 9f)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
|
||||||
|
assert(DataType.FLOAT == ndArray.dataType())
|
||||||
|
assert(3 == ndArray.rows())
|
||||||
|
assert(3 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create 2d matrix filled with double numbers" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
|
||||||
|
assert(DataType.DOUBLE == ndArray.dataType())
|
||||||
|
assert(3 == ndArray.rows())
|
||||||
|
assert(3 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create vector filled with short integers" in {
|
||||||
|
val ndArray = Array[Short](1, 2, 3).toNDArray
|
||||||
|
|
||||||
|
assert(DataType.SHORT == ndArray.dataType())
|
||||||
|
assert(1 == ndArray.rows())
|
||||||
|
assert(3 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create vector filled with byte values" in {
|
||||||
|
val ndArray = Array[Byte](1, 2, 3).toNDArray
|
||||||
|
|
||||||
|
assert(DataType.BYTE == ndArray.dataType())
|
||||||
|
assert(1 == ndArray.rows())
|
||||||
|
assert(3 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create vector filled with boolean values" in {
|
||||||
|
val ndArray = Array(true, false, true).toNDArray
|
||||||
|
|
||||||
|
assert(DataType.BOOL == ndArray.dataType())
|
||||||
|
assert(1 == ndArray.rows())
|
||||||
|
assert(3 == ndArray.columns())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create vector from integer range" in {
|
||||||
|
val list = (0 to 9).toNDArray
|
||||||
|
assert(DataType.INT == list.dataType())
|
||||||
|
|
||||||
|
val stepped = list(1 -> 7 by 2)
|
||||||
|
assert(Array(1, 3, 5).toNDArray == stepped)
|
||||||
|
assert(DataType.INT == list.dataType())
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to create vector from strings" in {
|
||||||
|
val oneString = "testme".toScalar
|
||||||
|
assert("testme" == oneString.getString(0))
|
||||||
|
assert(DataType.UTF8 == oneString.dataType())
|
||||||
|
|
||||||
|
val someStrings = Array[String]("one", "two", "three").toNDArray
|
||||||
|
assert("one" == someStrings.getString(0))
|
||||||
|
assert("two" == someStrings.getString(1))
|
||||||
|
assert("three" == someStrings.getString(2))
|
||||||
|
assert(DataType.UTF8 == someStrings.dataType())
|
||||||
|
}
|
||||||
|
}
|
|
@ -48,20 +48,48 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
assert(extracted == expected)
|
assert(extracted == expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "be able to extract a part of 2d matrix with offset" ignore { //Ignored AB 2019/05/21 - https://github.com/deeplearning4j/deeplearning4j/issues/7657
|
it should "be able to extract a part of 2d matrix with double data and offset" in {
|
||||||
val ndArray = (1 to 9).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
|
val ndArray = (1 to 9).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
|
||||||
|
|
||||||
val expectedArray = Array(
|
val expectedArray = Array(
|
||||||
Array(5, 6),
|
Array(5d, 6d),
|
||||||
Array(7, 8)
|
Array(7d, 8d)
|
||||||
).mkNDArray(ordering)
|
).mkNDArray(ordering)
|
||||||
assert(ndArray == expectedArray)
|
assert(ndArray == expectedArray)
|
||||||
|
|
||||||
val expectedSlice = Array(
|
val expectedSlice = Array(
|
||||||
Array(5),
|
Array(5d),
|
||||||
Array(7)
|
Array(7d)
|
||||||
).toNDArray
|
).toNDArray
|
||||||
assert(ndArray(->, 0) == expectedSlice)
|
assert(expectedArray(->, 0 -> 1) == expectedSlice)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "be able to extract a part of 2d matrix with integer data" in {
|
||||||
|
val ndArray = (1 to 9).mkNDArray(Array(2, 2))
|
||||||
|
|
||||||
|
val expectedArray = Array(
|
||||||
|
Array(1, 2),
|
||||||
|
Array(3, 4)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
assert(ndArray == expectedArray)
|
||||||
|
|
||||||
|
val expectedSlice = Array(
|
||||||
|
Array(1),
|
||||||
|
Array(3)
|
||||||
|
).toNDArray
|
||||||
|
val actualSlice = expectedArray(->, 0 -> 1)
|
||||||
|
assert(actualSlice == expectedSlice)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should " provide overloaded -> operator providing matrix slices as nd4j" in {
|
||||||
|
|
||||||
|
val expectedArray = (1 to 9).mkNDArray(Array(2, 2))
|
||||||
|
val expectedSlice = expectedArray.slice(0)
|
||||||
|
val actualSlice = expectedArray(0, ->)
|
||||||
|
|
||||||
|
Console.println(expectedSlice)
|
||||||
|
|
||||||
|
assert(actualSlice == expectedSlice)
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "be able to extract a part of vertically long matrix in" in {
|
it should "be able to extract a part of vertically long matrix in" in {
|
||||||
|
@ -163,7 +191,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
|
|
||||||
// TODO: fix me. This is about INDArray having to be sliced by LONG indices
|
// TODO: fix me. This is about INDArray having to be sliced by LONG indices
|
||||||
// can't find the correct way to fix implicits without breaking other stuff.
|
// can't find the correct way to fix implicits without breaking other stuff.
|
||||||
it should "be able to extract sub-matrix with index range by step" ignore {
|
it should "be able to extract sub-matrix with index range by step" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1, 2, 3),
|
Array(1, 2, 3),
|
||||||
|
@ -249,9 +277,10 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
"num2Scalar" should "convert number to Scalar INDArray" ignore { //Ignored AB 2019/05/21 - https://github.com/deeplearning4j/deeplearning4j/issues/7657
|
"num2Scalar" should "convert number to Scalar INDArray" in {
|
||||||
assert(1.toScalar == List(1).toNDArray)
|
|
||||||
assert(2f.toScalar == List(2).toNDArray)
|
assert(1.toScalar.data() == List(1).toNDArray.data())
|
||||||
assert(3d.toScalar == List(3).toNDArray)
|
assert(2f.toScalar.data() == List(2).toNDArray.data())
|
||||||
|
assert(3d.toScalar.data() == List(3).toNDArray.data())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,27 +26,27 @@ import org.scalatest.{ FlatSpec, Matchers }
|
||||||
class OperatableNDArrayTest extends FlatSpec with Matchers {
|
class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
"RichNDArray" should "use the apply method to access values" in {
|
"RichNDArray" should "use the apply method to access values" in {
|
||||||
// -- 2D array
|
// -- 2D array
|
||||||
val nd2 = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1))
|
val nd2 = Nd4j.create(Array[Double](1, 2, 3, 4), Array[Int](1, 4): _*)
|
||||||
|
|
||||||
nd2.get(0) should be(1)
|
nd2.get(0) should be(1)
|
||||||
nd2.get(3, 0) should be(4)
|
nd2.get(0, 3) should be(4)
|
||||||
|
|
||||||
// -- 3D array
|
// -- 3D array
|
||||||
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
|
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array[Int](2, 2, 2): _*)
|
||||||
nd3.get(0, 0, 0) should be(1)
|
nd3.get(0, 0, 0) should be(1)
|
||||||
nd3.get(1, 1, 1) should be(8)
|
nd3.get(1, 1, 1) should be(8)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "use transpose abbreviation" in {
|
it should "use transpose abbreviation" in {
|
||||||
val nd1 = Nd4j.create(Array[Double](1, 2, 3), Array(3, 1))
|
val nd1 = Nd4j.create(Array[Double](1, 2, 3), Array(3, 1): _*)
|
||||||
nd1.shape should equal(Array(3, 1))
|
nd1.shape should equal(Array(3, 1))
|
||||||
val nd1t = nd1.T
|
val nd1t = nd1.T
|
||||||
nd1t.shape should equal(Array(1, 3))
|
nd1t.shape should equal(Array(1, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "add correctly" in {
|
it should "add correctly" in {
|
||||||
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
|
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
|
||||||
val b = a + 100
|
val b = a + 100
|
||||||
a.get(0, 0, 0) should be(1)
|
a.get(0, 0, 0) should be(1)
|
||||||
b.get(0, 0, 0) should be(101)
|
b.get(0, 0, 0) should be(101)
|
||||||
|
@ -55,7 +55,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "subtract correctly" in {
|
it should "subtract correctly" in {
|
||||||
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
|
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
|
||||||
val b = a - 100
|
val b = a - 100
|
||||||
a.get(0, 0, 0) should be(1)
|
a.get(0, 0, 0) should be(1)
|
||||||
b.get(0, 0, 0) should be(-99)
|
b.get(0, 0, 0) should be(-99)
|
||||||
|
@ -69,7 +69,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "divide correctly" in {
|
it should "divide correctly" in {
|
||||||
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
|
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
|
||||||
val b = a / a
|
val b = a / a
|
||||||
a.get(1, 1, 1) should be(8)
|
a.get(1, 1, 1) should be(8)
|
||||||
b.get(1, 1, 1) should be(1)
|
b.get(1, 1, 1) should be(1)
|
||||||
|
@ -78,7 +78,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "element-by-element multiply correctly" in {
|
it should "element-by-element multiply correctly" in {
|
||||||
val a = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1))
|
val a = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1): _*)
|
||||||
val b = a * a
|
val b = a * a
|
||||||
a.get(3) should be(4) // [1.0, 2.0, 3.0, 4.0
|
a.get(3) should be(4) // [1.0, 2.0, 3.0, 4.0
|
||||||
b.get(3) should be(16) // [1.0 ,4.0 ,9.0 ,16.0]
|
b.get(3) should be(16) // [1.0 ,4.0 ,9.0 ,16.0]
|
||||||
|
@ -87,7 +87,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "use the update method to mutate values" in {
|
it should "use the update method to mutate values" in {
|
||||||
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
|
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
|
||||||
nd3(0) = 11
|
nd3(0) = 11
|
||||||
nd3.get(0) should be(11)
|
nd3.get(0) should be(11)
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
val s: String = Nd4j.create(2, 2) + ""
|
val s: String = Nd4j.create(2, 2) + ""
|
||||||
}
|
}
|
||||||
|
|
||||||
"Sum function" should "choose return value depending on INDArray type" ignore {
|
"Sum function" should "choose return value depending on INDArray type" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1, 2),
|
Array(1, 2),
|
||||||
|
|
Loading…
Reference in New Issue