[WIP] Fix compilation after nd4j changes (#37)

* Fix compilation.

* Some tests fixed

* Disable tests temporarily.

* Restored test

* Tests restored.

* Test restored.
master
Alexander Stoyakin 2019-11-08 10:25:44 +02:00 committed by GitHub
parent 0107fb10ab
commit 6958f2ba24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 21 deletions

View File

@ -61,7 +61,7 @@ lazy val commonSettings = Seq(
lazy val publishNexus = Seq( lazy val publishNexus = Seq(
publishTo := { publishTo := {
val nexus = "https://nexus.ci.skymind.io/" val nexus = "https://packages.konduit.ai/"
if (isSnapshot.value) if (isSnapshot.value)
Some("snapshots" at nexus + "content/repositories/maven-snapshots") Some("snapshots" at nexus + "content/repositories/maven-snapshots")
else else

View File

@ -80,7 +80,7 @@ object Implicits {
class IntArray2INDArray(val underlying: Array[Int]) extends AnyVal { class IntArray2INDArray(val underlying: Array[Int]) extends AnyVal {
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = { def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray = {
val strides = Nd4j.getStrides(shape, ord.value) val strides = Nd4j.getStrides(shape, ord.value)
Nd4j.create(underlying, shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.INT) Nd4j.create(underlying.map(_.toInt), shape.map(_.toLong), strides.map(_.toLong), ord.value, DataType.INT)
} }
def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*) def toNDArray: INDArray = Nd4j.createFromArray(underlying: _*)

View File

@ -170,9 +170,9 @@ class ConstructionTest extends FlatSpec with Matchers {
sd.setTrainingConfig(conf) sd.setTrainingConfig(conf)
sd.fit(new SingletonMultiDataSetIterator(mds), 1) sd.fit(new SingletonMultiDataSetIterator(mds), 1)
w.eval.toDoubleVector.head shouldBe (0.0629 +- 0.0001) w.getArr.get(0) shouldBe (0.0629 +- 0.0001)
w.eval.toDoubleVector.tail.head shouldBe (0.3128 +- 0.0001) w.getArr.get(1) shouldBe (0.3128 +- 0.0001)
w.eval.toDoubleVector.tail.tail.head shouldBe (0.2503 +- 0.0001) w.getArr.get(2) shouldBe (0.2503 +- 0.0001)
//Console.println(w.eval) //Console.println(w.eval)
} }
} }

View File

@ -209,7 +209,7 @@ class MathTest extends FlatSpec with Matchers {
val x = sd.bind(arr) val x = sd.bind(arr)
val y = new SDVariableWrapper(x) val y = new SDVariableWrapper(x)
x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr x.get(SDIndex.point(0)).eval shouldBe y(0).eval
} }
"SDVariable " should "be indexable in 2d" in { "SDVariable " should "be indexable in 2d" in {
@ -221,7 +221,7 @@ class MathTest extends FlatSpec with Matchers {
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval val slice1 = x.get(SDIndex.interval(0L, 2L), SDIndex.all()).eval
val slice2 = x(0 :: 2, ---).eval val slice2 = x(0 :: 2, ---).eval
slice1 shouldBe slice2 slice1 shouldBe slice2
} }
@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers {
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval x.get(SDIndex.interval(0L, 2L), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, x.get(SDIndex.interval(0L, 2L), SDIndex.interval(0L, 1L), SDIndex.interval(0L, 2L)).eval shouldBe x(0 :: 2,
0 :: 1, 0 :: 1,
0 :: 2).eval 0 :: 2).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval x.get(SDIndex.interval(0L, 2L), SDIndex.interval(0L, 1L), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
} }
} }

View File

@ -60,11 +60,11 @@ class SameDiffTest extends FlatSpec with Matchers {
sd.associateArrayWithVariable(inputArr, input) sd.associateArrayWithVariable(inputArr, input)
sd.associateArrayWithVariable(labelArr, label) sd.associateArrayWithVariable(labelArr, label)
val result: INDArray = sd.execAndEndResult val result = sd.output(null: java.util.Map[String, org.nd4j.linalg.api.ndarray.INDArray], "loss")
assertEquals(1, result.length) assertEquals(1, result.values().size())
val emptyMap = new HashMap[String, INDArray]() val emptyMap = new HashMap[String, INDArray]()
sd.execBackwards(emptyMap) sd.output(emptyMap, "loss")
} }
"SameDiff" should "run test dense layer forward pass" in { "SameDiff" should "run test dense layer forward pass" in {
@ -84,7 +84,7 @@ class SameDiffTest extends FlatSpec with Matchers {
val expMmul = iInput.mmul(iWeights) val expMmul = iInput.mmul(iWeights)
val expZ = expMmul.addRowVector(iBias) val expZ = expMmul.addRowVector(iBias)
val expOut = Transforms.sigmoid(expZ, true) val expOut = Transforms.sigmoid(expZ, true)
sd.exec(new HashMap[String, INDArray](), sd.outputs) sd.output(new HashMap[String, INDArray](), "mmul", "out", "bias", "add")
assertEquals(expMmul, mmul.getArr) assertEquals(expMmul, mmul.getArr)
assertEquals(expZ, z.getArr) assertEquals(expZ, z.getArr)
assertEquals(expOut, out.getArr) assertEquals(expOut, out.getArr)
@ -109,15 +109,18 @@ class SameDiffTest extends FlatSpec with Matchers {
.dataSetFeatureMapping("in", "in2") .dataSetFeatureMapping("in", "in2")
.skipBuilderValidation(true) .skipBuilderValidation(true)
.build .build
sd.setTrainingConfig(c)
sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr, inArr2), null)), 1) val data = new HashMap[String, INDArray]()
val out = tanh.eval data.put("in", Nd4j.randn(1, 3))
data.put("in2", Nd4j.randn(3, 4))
in.convertToConstant in.convertToConstant
val out2 = tanh.eval val out = sd.output(data, "tanh")
val out2 = sd.output(data, "tanh")
assertEquals(out, out2) assertEquals(out, out2)
assertEquals(VariableType.CONSTANT, in.getVariableType) assertEquals(VariableType.CONSTANT, in.getVariableType)
assertEquals(inArr, in.getArr) assertEquals(inArr, in.getArr)
//Sanity check on fitting: //Sanity check on fitting:
sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr2), null)), 1) sd.setTrainingConfig(c)
sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr, inArr2), null)), 1)
} }
} }