[WIP] nd4s - data types (#51)

* Fixed tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Added conversions for Long

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Added conversions for Long

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Added data types

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Failing test

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Types in conversions

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Added mixins for integer types

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Conversion of different types to scalar

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests for arrays construction

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Construction tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Fixed slicing

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Add own Executioner implementation to nd4s

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Filter operation activated

* Collection tests activated

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Types in operations

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Commented unused code

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Types in operations

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* String implicit conversion added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* String implicit conversion added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-07-12 18:18:39 +03:00 committed by AlexDBlack
parent c969b724bb
commit 68b82f3856
14 changed files with 1067 additions and 54 deletions

View File

@ -19,7 +19,7 @@ import org.nd4s.Implicits._
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.Op
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.ops.{ BitFilterOps, FilterOps, MapOps }
import org.nd4s.ops.{ BitFilterOps, FilterOps, FunctionalOpExecutioner, MapOps }
import scalaxy.loops._
import scala.language.postfixOps
@ -33,7 +33,7 @@ trait CollectionLikeNDArray[A <: INDArray] {
def filter(f: Double => Boolean)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
val shape = underlying.shape()
ev.reshape(Nd4j.getExecutioner
ev.reshape(FunctionalOpExecutioner.apply
.exec(FilterOps(ev.linearView(underlying), f): Op)
.asInstanceOf[A],
shape.map(_.toInt): _*)
@ -41,7 +41,7 @@ trait CollectionLikeNDArray[A <: INDArray] {
def filterBit(f: Double => Boolean)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
val shape = underlying.shape()
ev.reshape(Nd4j.getExecutioner
ev.reshape(FunctionalOpExecutioner.apply
.exec(BitFilterOps(ev.linearView(underlying), f): Op)
.asInstanceOf[A],
shape.map(_.toInt): _*)
@ -49,7 +49,7 @@ trait CollectionLikeNDArray[A <: INDArray] {
def map(f: Double => Double)(implicit ev: NDArrayEvidence[A, _]): A = notCleanedUp { _ =>
val shape = underlying.shape()
ev.reshape(Nd4j.getExecutioner
ev.reshape(FunctionalOpExecutioner.apply
.exec(MapOps(ev.linearView(underlying), f): Op)
.asInstanceOf[A],
shape.map(_.toInt): _*)

View File

@ -15,9 +15,11 @@
******************************************************************************/
package org.nd4s
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.indexing.{ INDArrayIndex, NDArrayIndex }
import scala.collection.breakOut
object Implicits {
@ -48,7 +50,7 @@ object Implicits {
Nd4j.create(underlying, shape, ord.value, offset)
def asNDArray(shape: Int*): INDArray =
Nd4j.create(underlying, shape.toArray: _*)
Nd4j.create(underlying.toArray, shape.toArray: _*)
def toNDArray: INDArray = Nd4j.create(underlying)
}
@ -64,7 +66,7 @@ object Implicits {
Nd4j.create(underlying, shape, offset, ord.value)
def asNDArray(shape: Int*): INDArray =
Nd4j.create(underlying, shape.toArray: _*)
Nd4j.create(underlying.toArray, shape.toArray: _*)
def toNDArray: INDArray = Nd4j.create(underlying)
}
@ -76,54 +78,168 @@ object Implicits {
implicit def jintColl2INDArray(s: Seq[java.lang.Integer]): IntArray2INDArray =
new IntArray2INDArray(s.map(x => x: Int)(breakOut))
class IntArray2INDArray(val underlying: Array[Int]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray =
Nd4j.create(underlying.map(_.toFloat), shape, ord.value, offset)
def asNDArray(shape: Int*): INDArray =
Nd4j.create(underlying.map(_.toFloat), shape.toArray: _*)
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toFloat).toArray)
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
val strides = Nd4j.getStrides(shape, ord.value)
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.INT)
}
implicit class FloatMtrix2INDArray(val underlying: Seq[Seq[Float]]) extends AnyVal {
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
}
implicit def longColl2INDArray(s: Seq[Long]): LongArray2INDArray =
new LongArray2INDArray(s.toArray)
implicit def longArray2INDArray(s: Array[Long]): LongArray2INDArray =
new LongArray2INDArray(s)
implicit def jlongColl2INDArray(s: Seq[java.lang.Long]): LongArray2INDArray =
new LongArray2INDArray(s.map(x => x: Long)(breakOut))
class LongArray2INDArray(val underlying: Array[Long]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
val strides = Nd4j.getStrides(shape, ord.value)
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.LONG)
}
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
}
implicit def shortColl2INDArray(s: Seq[Short]): ShortArray2INDArray =
new ShortArray2INDArray(s.toArray)
implicit def shortArray2INDArray(s: Array[Short]): ShortArray2INDArray =
new ShortArray2INDArray(s)
implicit def jshortColl2INDArray(s: Seq[java.lang.Short]): ShortArray2INDArray =
new ShortArray2INDArray(s.map(x => x: Short)(breakOut))
class ShortArray2INDArray(val underlying: Array[Short]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
val strides = Nd4j.getStrides(shape, ord.value)
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.SHORT)
}
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
}
implicit def byteColl2INDArray(s: Seq[Byte]): ByteArray2INDArray =
new ByteArray2INDArray(s.toArray)
implicit def byteArray2INDArray(s: Array[Byte]): ByteArray2INDArray =
new ByteArray2INDArray(s)
implicit def jbyteColl2INDArray(s: Seq[java.lang.Byte]): ByteArray2INDArray =
new ByteArray2INDArray(s.map(x => x: Byte)(breakOut))
class ByteArray2INDArray(val underlying: Array[Byte]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
val strides = Nd4j.getStrides(shape, ord.value)
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.BYTE)
}
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
}
implicit def booleanColl2INDArray(s: Seq[Boolean]): BooleanArray2INDArray =
new BooleanArray2INDArray(s.toArray)
implicit def booleanArray2INDArray(s: Array[Boolean]): BooleanArray2INDArray =
new BooleanArray2INDArray(s)
implicit def jbooleanColl2INDArray(s: Seq[java.lang.Boolean]): BooleanArray2INDArray =
new BooleanArray2INDArray(s.map(x => x: Boolean)(breakOut))
class BooleanArray2INDArray(val underlying: Array[Boolean]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
val strides = Nd4j.getStrides(shape, ord.value)
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.BOOL)
}
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)
}
implicit def stringArray2INDArray(s: Array[String]): StringArray2INDArray =
new StringArray2INDArray(s)
implicit def stringArray2CollArray(s: Seq[String]): StringArray2INDArray =
new StringArray2INDArray(s.toArray)
implicit def jstringColl2INDArray(s: Seq[java.lang.String]): StringArray2INDArray =
new StringArray2INDArray(s.map(x => x: String)(breakOut))
class StringArray2INDArray(val underlying: Array[String]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = ???
def asNDArray(shape: Int*): INDArray = ???
def toNDArray: INDArray = Nd4j.create(underlying: _*)
}
implicit class FloatMatrix2INDArray(val underlying: Seq[Seq[Float]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.create(underlying.map(_.toArray).toArray, ord.value)
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toArray).toArray)
}
implicit class FloatArrayMtrix2INDArray(val underlying: Array[Array[Float]]) extends AnyVal {
implicit class FloatArrayMatrix2INDArray(val underlying: Array[Array[Float]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.create(underlying, ord.value)
def toNDArray: INDArray = Nd4j.create(underlying)
}
implicit class DoubleMtrix2INDArray(val underlying: Seq[Seq[Double]]) extends AnyVal {
implicit class DoubleMatrix2INDArray(val underlying: Seq[Seq[Double]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.create(underlying.map(_.toArray).toArray, ord.value)
def toNDArray: INDArray = Nd4j.create(underlying.map(_.toArray).toArray)
}
implicit class DoubleArrayMtrix2INDArray(val underlying: Array[Array[Double]]) extends AnyVal {
implicit class DoubleArrayMatrix2INDArray(val underlying: Array[Array[Double]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.create(underlying, ord.value)
def toNDArray: INDArray = Nd4j.create(underlying)
}
implicit class IntMtrix2INDArray(val underlying: Seq[Seq[Int]]) extends AnyVal {
implicit class IntMatrix2INDArray(val underlying: Seq[Seq[Int]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.create(underlying.map(_.map(_.toFloat).toArray).toArray, ord.value)
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
def toNDArray: INDArray =
Nd4j.create(underlying.map(_.map(_.toFloat).toArray).toArray)
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
}
implicit class IntArrayMtrix2INDArray(val underlying: Array[Array[Int]]) extends AnyVal {
implicit class IntArrayMatrix2INDArray(val underlying: Array[Array[Int]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.create(underlying.map(_.map(_.toFloat)), ord.value)
def toNDArray: INDArray = Nd4j.create(underlying.map(_.map(_.toFloat)))
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
def toNDArray: INDArray = Nd4j.createFromArray(underlying.map(_.toArray).toArray)
}
implicit class Num2Scalar[T](val underlying: T)(implicit ev: Numeric[T]) {
implicit class LongMatrix2INDArray(val underlying: Seq[Seq[Long]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
def toNDArray: INDArray =
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
}
implicit class LongArrayMatrix2INDArray(val underlying: Array[Array[Long]]) extends AnyVal {
def mkNDArray(ord: NDOrdering): INDArray =
Nd4j.createFromArray(underlying.map(_.toArray).toArray)
def toNDArray: INDArray = Nd4j.createFromArray(underlying.map(_.toArray).toArray)
}
/*implicit class Num2Scalar[T](val underlying: T)(implicit ev: Numeric[T]) {
def toScalar: INDArray = Nd4j.scalar(ev.toDouble(underlying))
}*/
implicit class Float2Scalar(val underlying: Float) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit class Double2Scalar(val underlying: Double) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit class Long2Scalar(val underlying: Long) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit class Int2Scalar(val underlying: Int) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit class Byte2Scalar(val underlying: Byte) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit class Boolean2Scalar(val underlying: Boolean) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit class String2Scalar(val underlying: String) {
def toScalar: INDArray = Nd4j.scalar(underlying)
}
implicit def intArray2IndexRangeArray(arr: Array[Int]): Array[IndexRange] =

View File

@ -22,6 +22,9 @@ import org.nd4s.Implicits._
object Evidences {
implicit val double = DoubleNDArrayEvidence
implicit val float = FloatNDArrayEvidence
implicit val int = IntNDArrayEvidence
implicit val long = LongNDArrayEvidence
implicit val byte = ByteNDArrayEvidence
}
object NDArrayEvidence {
@ -342,3 +345,201 @@ case object FloatNDArrayEvidence extends RealNDArrayEvidence[Float] {
override def lessThan(left: Float, right: Float): Boolean = left < right
}
trait IntegerNDArrayEvidence[Value] extends NDArrayEvidence[INDArray, Value] {
def remainder(a: INDArray, that: INDArray): INDArray = a.remainder(that)
def add(a: INDArray, that: INDArray): INDArray = a.add(that)
def sub(a: INDArray, that: INDArray): INDArray = a.sub(that)
def mul(a: INDArray, that: INDArray): INDArray = a.mul(that)
def mmul(a: INDArray, that: INDArray): INDArray = a.mmul(that)
def div(a: INDArray, that: INDArray): INDArray = a.div(that)
def rdiv(a: INDArray, that: INDArray): INDArray = a.rdiv(that)
def addi(a: INDArray, that: INDArray): INDArray = a.addi(that)
def subi(a: INDArray, that: INDArray): INDArray = a.subi(that)
def muli(a: INDArray, that: INDArray): INDArray = a.muli(that)
def mmuli(a: INDArray, that: INDArray): INDArray = a.mmuli(that)
def remainderi(a: INDArray, that: INDArray): INDArray = a.remainder(that)
def remainderi(a: INDArray, that: Number): INDArray = a.remainderi(that)
def divi(a: INDArray, that: INDArray): INDArray = a.divi(that)
def rdivi(a: INDArray, that: INDArray): INDArray = a.rdivi(that)
def remainder(a: INDArray, that: Number): INDArray = a.remainder(that)
def add(a: INDArray, that: Number): INDArray = a.add(that)
def sub(a: INDArray, that: Number): INDArray = a.sub(that)
def mul(a: INDArray, that: Number): INDArray = a.mul(that)
def div(a: INDArray, that: Number): INDArray = a.div(that)
def rdiv(a: INDArray, that: Number): INDArray = a.rdiv(that)
def addi(a: INDArray, that: Number): INDArray = a.addi(that)
def subi(a: INDArray, that: Number): INDArray = a.subi(that)
def muli(a: INDArray, that: Number): INDArray = a.muli(that)
def divi(a: INDArray, that: Number): INDArray = a.divi(that)
def rdivi(a: INDArray, that: Number): INDArray = a.rdivi(that)
def put(a: INDArray, i: Int, element: INDArray): INDArray = a.put(i, element)
def put(a: INDArray, i: Array[Int], element: INDArray): INDArray = a.put(i, element)
def get(a: INDArray, i: INDArrayIndex*): INDArray = a.get(i: _*)
def reshape(a: INDArray, i: Int*): INDArray = a.reshape(i.map(_.toLong): _*)
def linearView(a: INDArray): INDArray = a.reshape(-1)
def dup(a: INDArray): INDArray = a.dup()
def update(a: INDArray, indices: Array[IndexRange], i: Int): INDArray =
a.update(indices, i)
def update(a: INDArray, indices: Array[IndexRange], i: INDArray): INDArray =
a.update(indices, i)
}
case object IntNDArrayEvidence extends IntegerNDArrayEvidence[Int] {
def sum(ndarray: INDArray): Int = ndarray.sumNumber().intValue()
def mean(ndarray: INDArray): Int = ndarray.meanNumber().intValue()
def normMax(ndarray: INDArray): Int = ndarray.normmaxNumber().intValue()
def norm1(ndarray: INDArray): Int = ndarray.norm1Number().intValue()
def norm2(ndarray: INDArray): Int = ndarray.norm2Number().intValue()
def max(ndarray: INDArray): Int = ndarray.maxNumber().intValue()
def min(ndarray: INDArray): Int = ndarray.minNumber().intValue()
def standardDeviation(ndarray: INDArray): Int = ndarray.stdNumber().intValue()
def product(ndarray: INDArray): Int = ndarray.prodNumber().intValue()
def variance(ndarray: INDArray): Int = ndarray.varNumber().intValue()
def get(a: INDArray, i: Int): Int = a.getInt(i)
def get(a: INDArray, i: Int, j: Int): Int = a.getInt(i, j)
def get(a: INDArray, i: Int*): Int = a.getInt(i: _*)
def create(arr: Array[Int]): INDArray = arr.toNDArray
def create(arr: Array[Int], shape: Int*): INDArray = arr.toNDArray
def create(arr: Array[Int], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
arr.mkNDArray(shape, ordering, offset)
def greaterThan(left: Int, right: Int): Boolean = left > right
def lessThan(left: Int, right: Int): Boolean = left < right
}
case object LongNDArrayEvidence extends IntegerNDArrayEvidence[Long] {
def sum(ndarray: INDArray): Long = ndarray.sumNumber().longValue()
def mean(ndarray: INDArray): Long = ndarray.meanNumber().longValue()
def normMax(ndarray: INDArray): Long = ndarray.normmaxNumber().longValue()
def norm1(ndarray: INDArray): Long = ndarray.norm1Number().longValue()
def norm2(ndarray: INDArray): Long = ndarray.norm2Number().longValue()
def max(ndarray: INDArray): Long = ndarray.maxNumber().longValue()
def min(ndarray: INDArray): Long = ndarray.minNumber().longValue()
def standardDeviation(ndarray: INDArray): Long = ndarray.stdNumber().longValue()
def product(ndarray: INDArray): Long = ndarray.prodNumber().longValue()
def variance(ndarray: INDArray): Long = ndarray.varNumber().longValue()
def get(a: INDArray, i: Int): Long = a.getLong(i)
def get(a: INDArray, i: Int, j: Int): Long = a.getLong(i, j)
def get(a: INDArray, i: Int*): Long = a.getLong(i.map(_.toLong): _*)
def create(arr: Array[Long]): INDArray = arr.toNDArray
def create(arr: Array[Long], shape: Int*): INDArray = arr.toNDArray
def create(arr: Array[Long], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
arr.mkNDArray(shape, ordering, offset)
def greaterThan(left: Long, right: Long): Boolean = left > right
def lessThan(left: Long, right: Long): Boolean = left < right
def update(a: INDArray, indices: Array[IndexRange], i: Long): INDArray =
a.update(indices, i)
}
case object ByteNDArrayEvidence extends IntegerNDArrayEvidence[Byte] {
def sum(ndarray: INDArray): Byte = ndarray.sumNumber().byteValue()
def mean(ndarray: INDArray): Byte = ndarray.meanNumber().byteValue()
def normMax(ndarray: INDArray): Byte = ndarray.normmaxNumber().byteValue()
def norm1(ndarray: INDArray): Byte = ndarray.norm1Number().byteValue()
def norm2(ndarray: INDArray): Byte = ndarray.norm2Number().byteValue()
def max(ndarray: INDArray): Byte = ndarray.maxNumber().byteValue()
def min(ndarray: INDArray): Byte = ndarray.minNumber().byteValue()
def standardDeviation(ndarray: INDArray): Byte = ndarray.stdNumber().byteValue()
def product(ndarray: INDArray): Byte = ndarray.prodNumber().byteValue()
def variance(ndarray: INDArray): Byte = ndarray.varNumber().byteValue()
def get(a: INDArray, i: Int): Byte = a.getInt(i).toByte
def get(a: INDArray, i: Int, j: Int): Byte = a.getInt(i, j).toByte
def get(a: INDArray, i: Int*): Byte = a.getInt(i.map(_.toInt): _*).toByte
def create(arr: Array[Byte]): INDArray = arr.toNDArray
def create(arr: Array[Byte], shape: Int*): INDArray = arr.toNDArray
def create(arr: Array[Byte], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
arr.mkNDArray(shape, ordering, offset)
def greaterThan(left: Byte, right: Byte): Boolean = left > right
def lessThan(left: Byte, right: Byte): Boolean = left < right
def update(a: INDArray, indices: Array[IndexRange], i: Byte): INDArray =
a.update(indices, i)
}

View File

@ -48,6 +48,7 @@ trait SliceableNDArray[A <: INDArray] {
ev.create(filtered, targetShape, NDOrdering(underlying.ordering()), 0)
} else {
ev.get(underlying, getINDArrayIndexfrom(target: _*): _*)
}
}
@ -121,9 +122,9 @@ trait SliceableNDArray[A <: INDArray] {
underlying.shape().map(_.toInt).toList
val originalTarget =
if (underlying.isRowVector && target.size == 1)
/*if (underlying.isRowVector && target.size == 1)
IntRange(0) +: target
else
else*/
target
@tailrec

View File

@ -55,4 +55,11 @@ class BitFilterOps(_x: INDArray, len: Int, f: Double => Boolean)
override def op(origin: Float): Float = if (f(origin)) 1 else 0
override def op(origin: Short): Short = if (f(origin)) 1 else 0
override def op(origin: Int): Int = if (f(origin)) 1 else 0
override def op(origin: Long): Long = if (f(origin)) 1 else 0
override def op(origin: String): String = ???
}

View File

@ -54,4 +54,12 @@ class FilterOps(_x: INDArray, len: Int, f: Double => Boolean)
override def op(origin: Float): Float = if (f(origin)) origin else 0
override def op(origin: Short): Short = if (f(origin)) origin else 0
override def op(origin: Int): Int = if (f(origin)) origin else 0
override def op(origin: Long): Long = if (f(origin)) origin else 0
override def op(origin: String): String = ???
}

View File

@ -0,0 +1,474 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4s.ops
import java.util.{ List, Map, Properties }
import org.bytedeco.javacpp.Pointer
import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer }
import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics }
import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch }
import org.nd4j.linalg.api.ops._
import org.nd4j.linalg.api.ops.executioner.OpExecutioner
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate
import org.nd4j.linalg.api.ops.impl.summarystats.Variance
import org.nd4j.linalg.api.rng.Random
import org.nd4j.linalg.api.shape.{ LongShapeDescriptor, TadPack }
import org.nd4j.linalg.cache.TADManager
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.profiler.ProfilerConfig
object FunctionalOpExecutioner {
def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner()
}
class FunctionalOpExecutioner extends OpExecutioner {
def isVerbose: Boolean = ???
/**
* This method returns true if debug mode is enabled, false otherwise
*
* @return
*/
def isDebug: Boolean = ???
/**
* This method returns type for this executioner instance
*
* @return
*/
def `type`: OpExecutioner.ExecutionerType = ???
/**
* This method returns opName of the last invoked op
*
* @return
*/
def getLastOp: String = ???
/**
* Execute the operation
*
* @param op the operation to execute
*/
def exec(op: Op): INDArray =
op match {
case op: FilterOps => exec(op.asInstanceOf[FilterOps])
case op: BitFilterOps => exec(op.asInstanceOf[BitFilterOps])
case op: MapOps => exec(op.asInstanceOf[MapOps])
case _ => op.z()
}
def exec(op: FilterOps): INDArray = {
val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
for (i <- 0 until op.x().length().toInt) {
val filtered = op.x.dataType() match {
case DataType.DOUBLE => op.op(op.x.getDouble(i.toLong))
case DataType.FLOAT => op.op(op.x.getFloat(i.toLong))
case DataType.INT => op.op(op.x.getInt(i))
case DataType.SHORT => op.op(op.x.getInt(i))
case (DataType.LONG) => op.op(op.x.getLong(i.toLong))
}
retVal.putScalar(i, filtered)
}
retVal
}
def exec(op: BitFilterOps): INDArray = {
val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
for (i <- 0 until op.x().length().toInt) {
val current = if (op.x.dataType() == DataType.DOUBLE) op.x().getDouble(i.toLong) else op.x().getInt(i)
val filtered = op.op(current)
retVal.putScalar(i, filtered)
}
retVal
}
def exec(op: MapOps): INDArray = {
val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
for (i <- 0 until op.x().length().toInt) {
val current = if (op.x.dataType() == DataType.DOUBLE) op.x().getDouble(i.toLong) else op.x().getInt(i)
val filtered = op.op(current)
retVal.putScalar(i, filtered)
}
retVal
}
/** Execute a TransformOp and return the result
*
* @param op the operation to execute
*/
def execAndReturn(op: TransformOp): TransformOp = ???
/**
* Execute and return the result from an accumulation
*
* @param op the operation to execute
* @return the accumulated result
*/
def execAndReturn(op: ReduceOp): ReduceOp = ???
def execAndReturn(op: Variance): Variance = ???
/** Execute and return the result from an index accumulation
*
* @param op the index accumulation operation to execute
* @return the accumulated index
*/
def execAndReturn(op: IndexAccumulation): IndexAccumulation = ???
/** Execute and return the result from a scalar op
*
* @param op the operation to execute
* @return the accumulated result
*/
def execAndReturn(op: ScalarOp): ScalarOp = ???
/** Execute and return the result from a vector op
*
* @param op */
def execAndReturn(op: BroadcastOp): BroadcastOp = ???
/**
* Execute a reduceOp, possibly along one or more dimensions
*
* @param reduceOp the reduceOp
* @return the reduceOp op
*/
def exec(reduceOp: ReduceOp): INDArray = ???
/**
* Execute a broadcast op, possibly along one or more dimensions
*
* @param broadcast the accumulation
* @return the broadcast op
*/
def exec(broadcast: BroadcastOp): INDArray = ???
/**
* Execute ScalarOp
*
* @param broadcast
* @return
*/
def exec(broadcast: ScalarOp): INDArray = ???
/**
* Execute an variance accumulation op, possibly along one or more dimensions
*
* @param accumulation the accumulation
* @return the accmulation op
*/
def exec(accumulation: Variance): INDArray = ???
/** Execute an index accumulation along one or more dimensions
*
* @param indexAccum the index accumulation operation
* @return result
*/
def exec(indexAccum: IndexAccumulation): INDArray = ???
/**
*
* Execute and return a result
* ndarray from the given op
*
* @param op the operation to execute
* @return the result from the operation
*/
def execAndReturn(op: Op): Op = ???
/**
* Execute MetaOp
*
* @param op
*/
def exec(op: MetaOp): Unit = ???
/**
* Execute GridOp
*
* @param op
*/
def exec(op: GridOp): Unit = ???
/**
*
* @param op
*/
def exec(op: Aggregate): Unit = ???
/**
* This method executes previously built batch
*
* @param batch
*/
def exec[T <: Aggregate](batch: Batch[T]): Unit = ???
/**
* This method takes arbitrary sized list of aggregates,
* and packs them into batches
*
* @param batch
*/
def exec(batch: java.util.List[Aggregate]): Unit = ???
/**
* This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
*
* @param op
*/
def exec(op: RandomOp): INDArray = ???
/**
* This method executes specific RandomOp against specified RNG
*
* @param op
* @param rng
*/
def exec(op: RandomOp, rng: Random): INDArray = ???
/**
* This method return set of key/value and
* key/key/value objects,
* describing current environment
*
* @return
*/
def getEnvironmentInformation: Properties = ???
/**
* This method specifies desired profiling mode
*
* @param mode
*/
@deprecated def setProfilingMode(mode: OpExecutioner.ProfilingMode): Unit = ???
/**
* This method stores specified configuration.
*
* @param config
*/
def setProfilingConfig(config: ProfilerConfig): Unit = ???
/**
* Ths method returns current profiling
*
* @return
*/
@deprecated def getProfilingMode: OpExecutioner.ProfilingMode = ???
/**
* This method returns TADManager instance used for this OpExecutioner
*
* @return
*/
def getTADManager: TADManager = ???
/**
* This method prints out environmental information returned by getEnvironmentInformation() method
*/
def printEnvironmentInformation(): Unit = ???
/**
* This method ensures all operations that supposed to be executed at this moment, are executed.
*/
def push(): Unit = ???
/**
* This method ensures all operations that supposed to be executed at this moment, are executed and finished.
*/
def commit(): Unit = ???
/**
* This method encodes array as thresholds, updating input array at the same time
*
* @param input
* @return encoded array is returned
*/
def thresholdEncode(input: INDArray, threshold: Double): INDArray = ???
def thresholdEncode(input: INDArray, threshold: Double, boundary: Integer): INDArray = ???
/**
* This method decodes thresholds array, and puts it into target array
*
* @param encoded
* @param target
* @return target is returned
*/
def thresholdDecode(encoded: INDArray, target: INDArray): INDArray = ???
/**
* This method returns number of elements affected by encoder
*
* @param indArray
* @param target
* @param threshold
* @return
*/
def bitmapEncode(indArray: INDArray, target: INDArray, threshold: Double): Long = ???
/**
*
* @param indArray
* @param threshold
* @return
*/
def bitmapEncode(indArray: INDArray, threshold: Double): INDArray = ???
/**
*
* @param encoded
* @param target
* @return
*/
def bitmapDecode(encoded: INDArray, target: INDArray): INDArray = ???
/**
* This method returns names of all custom operations available in current backend, and their number of input/output arguments
*
* @return
*/
def getCustomOperations: java.util.Map[String, CustomOpDescriptor] = ???
/**
* This method executes given CustomOp
*
* PLEASE NOTE: You're responsible for input/output validation
*
* @param op
*/
def execAndReturn(op: CustomOp): CustomOp = ???
def exec(op: CustomOp): Array[INDArray] = ???
/**
* This method executes op with given context
*
* @param op
* @param context
* @return method returns output arrays defined within context
*/
def exec(op: CustomOp, context: OpContext): Array[INDArray] = ???
def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = ???
/**
* Equivalent to calli
*/
def allocateOutputArrays(op: CustomOp): Array[INDArray] = ???
def enableDebugMode(reallyEnable: Boolean): Unit = ???
def enableVerboseMode(reallyEnable: Boolean): Unit = ???
def isExperimentalMode: Boolean = ???
def registerGraph(id: Long, graph: Pointer): Unit = ???
def executeGraph(id: Long,
map: java.util.Map[String, INDArray],
reverseMap: java.util.Map[String, Integer]): java.util.Map[String, INDArray] = ???
def forgetGraph(id: Long): Unit = ???
/**
* This method allows to set desired number of elements per thread, for performance optimization purposes.
* I.e. if array contains 2048 elements, and threshold is set to 1024, 2 threads will be used for given op execution.
*
* Default value: 1024
*
* @param threshold
*/
def setElementsThreshold(threshold: Int): Unit = ???
/**
* This method allows to set desired number of sub-arrays per thread, for performance optimization purposes.
* I.e. if matrix has shape of 64 x 128, and threshold is set to 8, each thread will be processing 8 sub-arrays (sure, if you have 8 core cpu).
* If your cpu has, say, 4, cores, only 4 threads will be spawned, and each will process 16 sub-arrays
*
* Default value: 8
*
* @param threshold
*/
def setTadThreshold(threshold: Int): Unit = ???
/**
* This method extracts String from Utf8Buffer
*
* @param buffer
* @param index
* @return
*/
def getString(buffer: Utf8Buffer, index: Long): String = ???
/**
* This method returns OpContext which can be used (and reused) to execute custom ops
*
* @return
*/
def buildContext: OpContext = ???
/**
*
* @param array
*/
def inspectArray(array: INDArray): INDArrayStatistics = ???
/**
* This method returns shapeInfo DataBuffer
*
* @param shape
* @param stride
* @param elementWiseStride
* @param order
* @param dtype
* @return
*/
def createShapeInfo(shape: Array[Long],
stride: Array[Long],
elementWiseStride: Long,
order: Char,
dtype: DataType,
empty: Boolean): DataBuffer = ???
/**
* This method returns host/device tad buffers
*/
def tadShapeInfoAndOffsets(array: INDArray, dimension: Array[Int]): TadPack = ???
/**
* This method returns constant buffer for the given jvm array
*
* @param values
* @return
*/
def createConstantBuffer(values: Array[Long], desiredType: DataType): DataBuffer = ???
def createConstantBuffer(values: Array[Int], desiredType: DataType): DataBuffer = ???
def createConstantBuffer(values: Array[Float], desiredType: DataType): DataBuffer = ???
def createConstantBuffer(values: Array[Double], desiredType: DataType): DataBuffer = ???
@deprecated def scatterUpdate(op: ScatterUpdate.UpdateOp,
array: INDArray,
indices: INDArray,
updates: INDArray,
axis: Array[Int]): Unit = ???
}

View File

@ -34,5 +34,13 @@ trait LeftAssociativeBinaryOp {
def op(origin: Float): Float
def op(origin: Short): Short
def op(origin: Int): Int
def op(origin: Long): Long
def op(origin: String): String
// def op(origin: IComplexNumber): IComplexNumber
}

View File

@ -49,4 +49,11 @@ class MapOps(_x: INDArray, f: Double => Double) extends BaseScalarOp(_x, null, _
override def op(origin: Float): Float = f(origin).toFloat
override def op(origin: Short): Short = f(origin).toShort
override def op(origin: Int): Int = f(origin).toInt
override def op(origin: Long): Long = f(origin).toLong
override def op(origin: String): String = ???
}

View File

@ -29,7 +29,7 @@ class DSLSpec extends FlatSpec with Matchers {
// This test just verifies that an INDArray gets wrapped with an implicit conversion
val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1))
val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1): _*)
val nd1 = nd + 10L // + creates new array, += modifies in place
nd.get(0) should equal(1)

View File

@ -19,7 +19,7 @@ import org.nd4s.Implicits._
import org.scalatest.{ FlatSpec, Matchers }
class NDArrayCollectionAPITest extends FlatSpec with Matchers {
"CollectionLikeNDArray" should "provides filter API" ignore {
"CollectionLikeNDArray" should "provides filter API" in {
val ndArray =
Array(
Array(1, 2, 3),
@ -38,7 +38,48 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers {
).toNDArray
)
}
it should "provides filter bitmask API" ignore {
"CollectionLikeNDArray from Floats" should "provides filter API" in {
val ndArray =
Array(
Array(1f, 2f, 3f),
Array(4f, 5f, 6f),
Array(7f, 8f, 9f)
).toNDArray
val filtered = ndArray.filter(_ > 3)
assert(
filtered ==
Array(
Array(0f, 0f, 0f),
Array(4f, 5f, 6f),
Array(7f, 8f, 9f)
).toNDArray
)
}
"CollectionLikeNDArray from Long " should "provides filter API" in {
val ndArray =
Array(
Array(1L, 2L, 3L),
Array(4L, 5L, 6L),
Array(7L, 8L, 9L)
).toNDArray
val filtered = ndArray.filter(_ > 3)
assert(
filtered ==
Array(
Array(0L, 0L, 0L),
Array(4L, 5L, 6L),
Array(7L, 8L, 9L)
).toNDArray
)
}
it should "provides filter bitmask API" in {
val ndArray =
Array(
Array(1, 2, 3),
@ -57,7 +98,7 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers {
).toNDArray
)
}
it should "provides map API" ignore {
it should "provides map API" in {
val ndArray =
Array(
Array(1, 2, 3),

View File

@ -0,0 +1,121 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4s
import org.nd4j.linalg.api.buffer.DataType
import org.nd4s.Implicits._
import org.scalatest.FlatSpec
class NDArrayConstructionTest extends FlatSpec with COrderingForTest {
self: OrderingForTest =>
it should "be able to create 2d matrix filled with integers" in {
val ndArray =
Array(
Array(1, 2),
Array(4, 5),
Array(7, 9)
).mkNDArray(ordering)
assert(DataType.INT == ndArray.dataType())
assert(3 == ndArray.rows())
assert(2 == ndArray.columns())
}
it should "be able to create 2d matrix filled with long integers" in {
val ndArray =
Array(
Array(1L, 2L, 3L),
Array(4L, 5L, 6L),
Array(7L, 8L, 9L)
).mkNDArray(ordering)
assert(DataType.LONG == ndArray.dataType())
assert(3 == ndArray.rows())
assert(3 == ndArray.columns())
}
it should "be able to create 2d matrix filled with float numbers" in {
val ndArray =
Array(
Array(1f, 2f, 3f),
Array(4f, 5f, 6f),
Array(7f, 8f, 9f)
).mkNDArray(ordering)
assert(DataType.FLOAT == ndArray.dataType())
assert(3 == ndArray.rows())
assert(3 == ndArray.columns())
}
it should "be able to create 2d matrix filled with double numbers" in {
val ndArray =
Array(
Array(1d, 2d, 3d),
Array(4d, 5d, 6d),
Array(7d, 8d, 9d)
).mkNDArray(ordering)
assert(DataType.DOUBLE == ndArray.dataType())
assert(3 == ndArray.rows())
assert(3 == ndArray.columns())
}
it should "be able to create vector filled with short integers" in {
val ndArray = Array[Short](1, 2, 3).toNDArray
assert(DataType.SHORT == ndArray.dataType())
assert(1 == ndArray.rows())
assert(3 == ndArray.columns())
}
it should "be able to create vector filled with byte values" in {
val ndArray = Array[Byte](1, 2, 3).toNDArray
assert(DataType.BYTE == ndArray.dataType())
assert(1 == ndArray.rows())
assert(3 == ndArray.columns())
}
it should "be able to create vector filled with boolean values" in {
val ndArray = Array(true, false, true).toNDArray
assert(DataType.BOOL == ndArray.dataType())
assert(1 == ndArray.rows())
assert(3 == ndArray.columns())
}
it should "be able to create vector from integer range" in {
val list = (0 to 9).toNDArray
assert(DataType.INT == list.dataType())
val stepped = list(1 -> 7 by 2)
assert(Array(1, 3, 5).toNDArray == stepped)
assert(DataType.INT == list.dataType())
}
it should "be able to create vector from strings" in {
val oneString = "testme".toScalar
assert("testme" == oneString.getString(0))
assert(DataType.UTF8 == oneString.dataType())
val someStrings = Array[String]("one", "two", "three").toNDArray
assert("one" == someStrings.getString(0))
assert("two" == someStrings.getString(1))
assert("three" == someStrings.getString(2))
assert(DataType.UTF8 == someStrings.dataType())
}
}

View File

@ -48,20 +48,48 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
assert(extracted == expected)
}
it should "be able to extract a part of 2d matrix with offset" ignore { //Ignored AB 2019/05/21 - https://github.com/deeplearning4j/deeplearning4j/issues/7657
val ndArray = (1 to 9).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
it should "be able to extract a part of 2d matrix with double data and offset" in {
val ndArray = (1 to 9).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
val expectedArray = Array(
Array(5, 6),
Array(7, 8)
Array(5d, 6d),
Array(7d, 8d)
).mkNDArray(ordering)
assert(ndArray == expectedArray)
val expectedSlice = Array(
Array(5),
Array(7)
Array(5d),
Array(7d)
).toNDArray
assert(ndArray(->, 0) == expectedSlice)
assert(expectedArray(->, 0 -> 1) == expectedSlice)
}
it should "be able to extract a part of 2d matrix with integer data" in {
val ndArray = (1 to 9).mkNDArray(Array(2, 2))
val expectedArray = Array(
Array(1, 2),
Array(3, 4)
).mkNDArray(ordering)
assert(ndArray == expectedArray)
val expectedSlice = Array(
Array(1),
Array(3)
).toNDArray
val actualSlice = expectedArray(->, 0 -> 1)
assert(actualSlice == expectedSlice)
}
it should " provide overloaded -> operator providing matrix slices as nd4j" in {
val expectedArray = (1 to 9).mkNDArray(Array(2, 2))
val expectedSlice = expectedArray.slice(0)
val actualSlice = expectedArray(0, ->)
Console.println(expectedSlice)
assert(actualSlice == expectedSlice)
}
it should "be able to extract a part of vertically long matrix in" in {
@ -163,7 +191,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
// TODO: fix me. This is about INDArray having to be sliced by LONG indices
// can't find the correct way to fix implicits without breaking other stuff.
it should "be able to extract sub-matrix with index range by step" ignore {
it should "be able to extract sub-matrix with index range by step" in {
val ndArray =
Array(
Array(1, 2, 3),
@ -249,9 +277,10 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
)
}
"num2Scalar" should "convert number to Scalar INDArray" ignore { //Ignored AB 2019/05/21 - https://github.com/deeplearning4j/deeplearning4j/issues/7657
assert(1.toScalar == List(1).toNDArray)
assert(2f.toScalar == List(2).toNDArray)
assert(3d.toScalar == List(3).toNDArray)
"num2Scalar" should "convert number to Scalar INDArray" in {
assert(1.toScalar.data() == List(1).toNDArray.data())
assert(2f.toScalar.data() == List(2).toNDArray.data())
assert(3d.toScalar.data() == List(3).toNDArray.data())
}
}

View File

@ -26,27 +26,27 @@ import org.scalatest.{ FlatSpec, Matchers }
class OperatableNDArrayTest extends FlatSpec with Matchers {
"RichNDArray" should "use the apply method to access values" in {
// -- 2D array
val nd2 = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1))
val nd2 = Nd4j.create(Array[Double](1, 2, 3, 4), Array[Int](1, 4): _*)
nd2.get(0) should be(1)
nd2.get(3, 0) should be(4)
nd2.get(0, 3) should be(4)
// -- 3D array
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array[Int](2, 2, 2): _*)
nd3.get(0, 0, 0) should be(1)
nd3.get(1, 1, 1) should be(8)
}
it should "use transpose abbreviation" in {
val nd1 = Nd4j.create(Array[Double](1, 2, 3), Array(3, 1))
val nd1 = Nd4j.create(Array[Double](1, 2, 3), Array(3, 1): _*)
nd1.shape should equal(Array(3, 1))
val nd1t = nd1.T
nd1t.shape should equal(Array(1, 3))
}
it should "add correctly" in {
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
val b = a + 100
a.get(0, 0, 0) should be(1)
b.get(0, 0, 0) should be(101)
@ -55,7 +55,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
}
it should "subtract correctly" in {
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
val b = a - 100
a.get(0, 0, 0) should be(1)
b.get(0, 0, 0) should be(-99)
@ -69,7 +69,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
}
it should "divide correctly" in {
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
val b = a / a
a.get(1, 1, 1) should be(8)
b.get(1, 1, 1) should be(1)
@ -78,7 +78,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
}
it should "element-by-element multiply correctly" in {
val a = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1))
val a = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1): _*)
val b = a * a
a.get(3) should be(4) // [1.0, 2.0, 3.0, 4.0
b.get(3) should be(16) // [1.0 ,4.0 ,9.0 ,16.0]
@ -87,7 +87,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
}
it should "use the update method to mutate values" in {
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
nd3(0) = 11
nd3.get(0) should be(11)
@ -125,7 +125,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
val s: String = Nd4j.create(2, 2) + ""
}
"Sum function" should "choose return value depending on INDArray type" ignore {
"Sum function" should "choose return value depending on INDArray type" in {
val ndArray =
Array(
Array(1, 2),