parent
b2f86ab8a8
commit
76aab75174
|
@ -1,95 +0,0 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>nd4j-camel-routes</artifactId>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>nd4j-kafka_2.11</artifactId>
|
|
||||||
<packaging>jar</packaging>
|
|
||||||
|
|
||||||
<name>nd4j-kafka</name>
|
|
||||||
<url>https://deeplearning4j.org</url>
|
|
||||||
|
|
||||||
<properties>
|
|
||||||
<!-- Default scala versions, may be overwritten by build profiles -->
|
|
||||||
<scala.version>2.11.12</scala.version>
|
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-api</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.scala-lang</groupId>
|
|
||||||
<artifactId>scala-library</artifactId>
|
|
||||||
<version>${scala.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>io.netty</groupId>
|
|
||||||
<artifactId>netty</artifactId>
|
|
||||||
<version>${netty.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>net.jpountz.lz4</groupId>
|
|
||||||
<artifactId>lz4</artifactId>
|
|
||||||
<version>${lz4.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.xerial.snappy</groupId>
|
|
||||||
<artifactId>snappy-java</artifactId>
|
|
||||||
<version>${snappy.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.camel</groupId>
|
|
||||||
<artifactId>camel-kafka</artifactId>
|
|
||||||
<version>${camel.version}</version>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-log4j12</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>log4j</groupId>
|
|
||||||
<artifactId>log4j</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-common-tests</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>testresources</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,83 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.kafka;
|
|
||||||
|
|
||||||
import org.apache.camel.CamelContext;
|
|
||||||
import org.apache.camel.impl.DefaultCamelContext;
|
|
||||||
import org.junit.After;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.BaseND4JTest;
|
|
||||||
import org.nd4j.camel.kafka.KafkaConnectionInformation;
|
|
||||||
import org.nd4j.camel.kafka.Nd4jKafkaConsumer;
|
|
||||||
import org.nd4j.camel.kafka.Nd4jKafkaProducer;
|
|
||||||
import org.nd4j.camel.kafka.Nd4jKafkaRoute;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by agibsonccc on 7/19/16.
|
|
||||||
*/
|
|
||||||
public class Nd4jKafkaRouteTest extends BaseND4JTest {
|
|
||||||
private EmbeddedKafkaCluster kafka;
|
|
||||||
private EmbeddedZookeeper zk;
|
|
||||||
private CamelContext camelContext;
|
|
||||||
public final static String TOPIC = "nd4jtest";
|
|
||||||
public final static String GROUP_ID = "nd4j";
|
|
||||||
private KafkaConnectionInformation connectionInformation;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void before() throws Exception {
|
|
||||||
zk = new EmbeddedZookeeper(TestUtils.getAvailablePort());
|
|
||||||
zk.startup();
|
|
||||||
kafka = new EmbeddedKafkaCluster(zk.getConnection());
|
|
||||||
kafka.startup();
|
|
||||||
kafka.createTopics(TOPIC);
|
|
||||||
camelContext = new DefaultCamelContext();
|
|
||||||
camelContext.start();
|
|
||||||
connectionInformation = KafkaConnectionInformation.builder().groupId(GROUP_ID).topicName(TOPIC)
|
|
||||||
.zookeeperHost("localhost").zookeeperPort(zk.getPort()).kafkaBrokerList(kafka.getBrokerList())
|
|
||||||
.build();
|
|
||||||
camelContext.addRoutes(Nd4jKafkaRoute.builder().kafkaConnectionInformation(connectionInformation).build());
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
|
||||||
public void after() throws Exception {
|
|
||||||
if (kafka != null)
|
|
||||||
kafka.shutdown();
|
|
||||||
if (zk != null)
|
|
||||||
zk.shutdown();
|
|
||||||
if (camelContext != null)
|
|
||||||
camelContext.stop();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testKafkaRoute() throws Exception {
|
|
||||||
Nd4jKafkaProducer kafkaProducer = Nd4jKafkaProducer.builder().camelContext(camelContext)
|
|
||||||
.connectionInformation(connectionInformation).build();
|
|
||||||
kafkaProducer.publish(Nd4j.create(4));
|
|
||||||
Nd4jKafkaConsumer consumer = Nd4jKafkaConsumer.builder().camelContext(camelContext)
|
|
||||||
.connectionInformation(connectionInformation).build();
|
|
||||||
assertEquals(Nd4j.create(4), consumer.receive());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,171 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>nd4j-serde</artifactId>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>nd4j-gson</artifactId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-api</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.google.code.gson</groupId>
|
|
||||||
<artifactId>gson</artifactId>
|
|
||||||
<version>${gson.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>junit</groupId>
|
|
||||||
<artifactId>junit</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-common-tests</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>testresources</id>
|
|
||||||
</profile>
|
|
||||||
|
|
||||||
<profile>
|
|
||||||
<id>nd4j-tests-cpu</id>
|
|
||||||
<activation>
|
|
||||||
<activeByDefault>false</activeByDefault>
|
|
||||||
</activation>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-native</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
<LD_LIBRARY_PATH>${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/
|
|
||||||
</LD_LIBRARY_PATH>
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>junit:junit</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
|
||||||
-->
|
|
||||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
|
||||||
|
|
||||||
<profile>
|
|
||||||
<id>nd4j-tests-cuda</id>
|
|
||||||
<activation>
|
|
||||||
<activeByDefault>false</activeByDefault>
|
|
||||||
</activation>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit47</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
<LD_LIBRARY_PATH>
|
|
||||||
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/
|
|
||||||
</LD_LIBRARY_PATH>
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>junit:junit</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
-->
|
|
||||||
<argLine>-Ddtype=float -Xmx6g</argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
|
|
||||||
</project>
|
|
|
@ -1,101 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.serde.gson;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.BaseND4JTest;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
public class GsonDeserializationUtilsTest extends BaseND4JTest {
|
|
||||||
@Test
|
|
||||||
public void deserializeRawJson_PassInInRank3Array_ExpectCorrectDeserialization() {
|
|
||||||
String serializedRawArray = "[[[1.00, 11.00, 3.00],\n" + "[13.00, 5.00, 15.00],\n" + "[7.00, 17.00, 9.00]]]";
|
|
||||||
INDArray expectedArray = buildExpectedArray(1, 1, 3, 3);
|
|
||||||
|
|
||||||
INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray);
|
|
||||||
|
|
||||||
assertEquals(expectedArray, indArray);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void deserializeRawJson_ArrayHasOnlyOneRowWithColumns_ExpectCorrectDeserialization() {
|
|
||||||
String serializedRawArray = "[1.00, 11.00, 3.00]";
|
|
||||||
INDArray expectedArray = Nd4j.create(new double[] {1, 11, 3}).castTo(Nd4j.defaultFloatingPointType());
|
|
||||||
|
|
||||||
INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray);
|
|
||||||
|
|
||||||
assertEquals(expectedArray, indArray);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void deserializeRawJson_ArrayIsRankFive_ExpectCorrectDeserialization() {
|
|
||||||
String serializedRawArray = "[[[[[1.00, 11.00],\n" + " [3.00, 13.00]],\n" + " [[5.00, 15.00],\n"
|
|
||||||
+ " [7.00, 17.00]]],\n" + " [[[9.00, 1.00],\n" + " [11.00, 3.00]],\n"
|
|
||||||
+ " [[13.00, 5.00],\n" + " [15.00, 7.00]]],\n" + " [[[17.00, 9.00],\n"
|
|
||||||
+ " [1.00, 11.00]],\n" + " [[3.00, 13.00],\n" + " [5.00, 15.00]]]],\n"
|
|
||||||
+ " [[[[7.00, 17.00],\n" + " [9.00, 1.00]],\n" + " [[11.00, 3.00],\n"
|
|
||||||
+ " [13.00, 5.00]]],\n" + " [[[15.00, 7.00],\n" + " [17.00, 9.00]],\n"
|
|
||||||
+ " [[1.00, 11.00],\n" + " [3.00, 13.00]]],\n" + " [[[5.00, 15.00],\n"
|
|
||||||
+ " [7.00, 17.00]],\n" + " [[9.00, 1.00],\n" + " [11.00, 3.00]]]],\n"
|
|
||||||
+ " [[[[13.00, 5.00],\n" + " [15.00, 7.00]],\n" + " [[17.00, 9.00],\n"
|
|
||||||
+ " [1.00, 11.00]]],\n" + " [[[3.00, 13.00],\n" + " [5.00, 15.00]],\n"
|
|
||||||
+ " [[7.00, 17.00],\n" + " [9.00, 1.00]]],\n" + " [[[11.00, 3.00],\n"
|
|
||||||
+ " [13.00, 5.00]],\n" + " [[15.00, 7.00],\n" + " [17.00, 9.00]]]]]";
|
|
||||||
INDArray expectedArray = buildExpectedArray(8, 3, 3, 2, 2, 2);
|
|
||||||
|
|
||||||
INDArray array = GsonDeserializationUtils.deserializeRawJson(serializedRawArray);
|
|
||||||
|
|
||||||
assertEquals(expectedArray, array);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSimpleVector() {
|
|
||||||
INDArray arr = Nd4j.linspace(1, 4, 4, Nd4j.defaultFloatingPointType()).reshape(4);
|
|
||||||
INDArray out = GsonDeserializationUtils.deserializeRawJson(arr.toString());
|
|
||||||
assertEquals(arr, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void deserializeRawJson_HaveCommaInsideNumbers_ExpectCorrectDeserialization() {
|
|
||||||
String serializedRawArray =
|
|
||||||
"[[1.00, 1100.00, 3.00],\n" + "[13.00, 5.00, 15591.00],\n" + "[7000.00, 17.00, 9.00]]";
|
|
||||||
INDArray expectedArray = Nd4j.create(new double[] {1, 1100, 3, 13, 5, 15591, 7000, 17, 9}, new int[] {3, 3}).castTo(Nd4j.defaultFloatingPointType());
|
|
||||||
|
|
||||||
INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray);
|
|
||||||
|
|
||||||
assertEquals(expectedArray, indArray);
|
|
||||||
}
|
|
||||||
|
|
||||||
private INDArray buildExpectedArray(int numberOfTripletRows, int... shape) {
|
|
||||||
INDArray expectedArray = Nd4j.create(3 * numberOfTripletRows, 3);
|
|
||||||
for (int i = 0; i < numberOfTripletRows; i++) {
|
|
||||||
int index = 3 * i;
|
|
||||||
expectedArray.putRow(index, Nd4j.create(new double[] {1, 11, 3}));
|
|
||||||
expectedArray.putRow(index + 1, Nd4j.create(new double[] {13, 5, 15}));
|
|
||||||
expectedArray.putRow(index + 2, Nd4j.create(new double[] {7, 17, 9}));
|
|
||||||
}
|
|
||||||
|
|
||||||
return expectedArray.reshape(shape);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -185,16 +185,6 @@
|
||||||
<artifactId>nd4j-kryo_2.11</artifactId>
|
<artifactId>nd4j-kryo_2.11</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-kafka_2.11</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-gson</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-common</artifactId>
|
<artifactId>nd4j-common</artifactId>
|
||||||
|
|
Loading…
Reference in New Issue