Merge pull request #8788 from KonduitAI/master

Update master based on latest development work

Commit: 9f523d6811
@@ -196,6 +196,11 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
         // 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
         //Based on: Nd4j Shape.ind2sub

+        int countNon1 = 0;
+        for( int i : numValuesPerParam)
+            if(i > 1)
+                countNon1++;
+
         int denom = product;
         int num = candidateIdx;
         int[] index = new int[numValuesPerParam.length];

@@ -209,12 +214,11 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
         //Now: convert indexes to values in range [0,1]
         //min value -> 0
         //max value -> 1
-        double[] out = new double[numValuesPerParam.length];
-        for (int i = 0; i < out.length; i++) {
-            if (numValuesPerParam[i] <= 1)
-                out[i] = 0.0;
-            else {
-                out[i] = index[i] / ((double) (numValuesPerParam[i] - 1));
+        double[] out = new double[countNon1];
+        int outIdx = 0;
+        for (int i = 0; i < numValuesPerParam.length; i++) {
+            if (numValuesPerParam[i] > 1){
+                out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
             }
         }

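[Reviewer note: the added countNon1 pass shrinks the output array so that parameters with a single fixed value no longer occupy a slot in the [0,1] vector. The index decoding itself follows Nd4j's Shape.ind2sub, with the first dimension varying fastest. A standalone sketch of the mapping, using hypothetical grid sizes that are not part of the patch:

    // Grid sizes {3, 1, 3}: middle parameter is fixed, 9 candidates in total.
    int[] numValuesPerParam = {3, 1, 3};
    int candidateIdx = 5;
    int denom = 9;                                   // product of all grid sizes
    int num = candidateIdx;
    int[] index = new int[numValuesPerParam.length];
    for (int i = numValuesPerParam.length - 1; i >= 0; i--) {
        denom /= numValuesPerParam[i];               // decode, first dimension fastest
        index[i] = num / denom;
        num %= denom;
    }
    // index == {2, 0, 1}; with the fix, out has only 2 entries (fixed param skipped):
    // out == {2/2.0, 1/2.0} == {1.0, 0.5}
]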
@@ -21,6 +21,7 @@ import org.deeplearning4j.arbiter.DL4JConfiguration;
 import org.deeplearning4j.arbiter.MultiLayerSpace;
 import org.deeplearning4j.arbiter.TestUtils;
 import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
+import org.deeplearning4j.arbiter.conf.updater.NesterovsSpace;
 import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
 import org.deeplearning4j.arbiter.layers.*;
 import org.deeplearning4j.arbiter.optimize.api.Candidate;

@@ -80,6 +81,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
 import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
 import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.primitives.Pair;

 import java.io.File;
 import java.lang.reflect.Field;
@@ -767,4 +769,52 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
             assertEquals(expCandidates, count);
         }
     }
+
+    @Test
+    public void testGridCandidateGenerator(){
+        ParameterSpace<Integer> layerSizeParam = new DiscreteParameterSpace<>(32, 48, 64);
+        ParameterSpace<Double> learningRateParam = new DiscreteParameterSpace<>(0.005, 0.007, 0.01);
+
+        MultiLayerSpace hyperParamaterSpace = new MultiLayerSpace.Builder()
+                .seed(12345)
+                .biasInit(1)
+                .l2(1e-4)
+                .updater(new NesterovsSpace(learningRateParam))
+                .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(layerSizeParam)
+                        .weightInit(WeightInit.XAVIER)
+                        .activation(Activation.RELU)
+                        .build())
+                .addLayer(new DenseLayerSpace.Builder().nIn(layerSizeParam).nOut(layerSizeParam)
+                        .weightInit(WeightInit.XAVIER)
+                        .activation(Activation.RELU)
+                        .build())
+                .addLayer(new OutputLayerSpace.Builder()
+                        .lossFunction(LossFunctions.LossFunction.MSE)
+                        .weightInit(WeightInit.XAVIER)
+                        .activation(Activation.SOFTMAX)
+                        .nIn(layerSizeParam).nOut(10).build())
+                .build();
+
+        CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(hyperParamaterSpace, 30, GridSearchCandidateGenerator.Mode.Sequential, null);
+//        CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperParamaterSpace);
+
+        Set<Pair<Double,Integer>> expCandidates = new HashSet<>();
+        for(Double d : new double[]{0.005, 0.007, 0.01}){
+            for(int i : new int[]{32, 48, 64}){
+                expCandidates.add(new Pair<>(d, i));
+            }
+        }
+
+        Set<Pair<Double,Integer>> actCandidates = new HashSet<>();
+        while(candidateGenerator.hasMoreCandidates()) {
+            Candidate<DL4JConfiguration> conf = candidateGenerator.getCandidate();
+            MultiLayerConfiguration mlc = conf.getValue().getMultiLayerConfiguration();
+            FeedForwardLayer ffl = ((FeedForwardLayer) mlc.getConf(0).getLayer());
+//            System.out.println(ffl.getIUpdater() + ", " + ffl.getNOut());
+            actCandidates.add(new Pair<>(ffl.getIUpdater().getLearningRate(0,0), (int)ffl.getNOut()));
+        }
+
+        assertEquals(expCandidates, actCandidates);
+    }
 }
@@ -55,6 +55,10 @@ import static org.junit.Assert.assertEquals;
 @Slf4j
 public class ArbiterCLIRunnerTest extends BaseDL4JTest {

+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000;
+    }

     @Test
     public void testCliRunner() throws Exception {
@@ -67,7 +71,7 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
                 .l2(new ContinuousParameterSpace(0.0001, 0.01))
                 .addLayer(new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2,10))
                         .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
-                        .build(),new IntegerParameterSpace(1,2),true) //1-2 identical layers (except nIn)
+                        .build())
                 .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                 .numEpochs(3).build();
@@ -37,6 +37,10 @@
         <artifactId>nd4j-api</artifactId>
         <version>${project.version}</version>
     </dependency>
+    <dependency>
+        <groupId>ch.qos.logback</groupId>
+        <artifactId>logback-classic</artifactId>
+    </dependency>
 </dependencies>

 <profiles>
@@ -17,6 +17,7 @@

 package org.deeplearning4j;

+import ch.qos.logback.classic.LoggerContext;
 import lombok.extern.slf4j.Slf4j;
 import org.bytedeco.javacpp.Pointer;
 import org.junit.After;

@@ -31,6 +32,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.profiler.ProfilerConfig;
+import org.slf4j.ILoggerFactory;
+import org.slf4j.LoggerFactory;

 import java.lang.management.ManagementFactory;
 import java.util.List;

@@ -86,12 +89,12 @@ public abstract class BaseDL4JTest {
         return getDataType();
     }

-    protected Boolean integrationTest;
+    protected static Boolean integrationTest;

     /**
      * @return True if integration tests maven profile is enabled, false otherwise.
      */
-    public boolean isIntegrationTests(){
+    public static boolean isIntegrationTests(){
         if(integrationTest == null){
             String prop = System.getenv("DL4J_INTEGRATION_TESTS");
             integrationTest = Boolean.parseBoolean(prop);

@@ -104,7 +107,7 @@ public abstract class BaseDL4JTest {
      * This can be used to dynamically skip integration tests when the integration test profile is not enabled.
      * Note that the integration test profile is not enabled by default - "integration-tests" profile
      */
-    public void skipUnlessIntegrationTests(){
+    public static void skipUnlessIntegrationTests(){
         assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
     }
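[Reviewer note: making integrationTest, isIntegrationTests() and skipUnlessIntegrationTests() static lets the profile check run before any test instance exists. A hedged usage sketch with a hypothetical test class, not part of the patch:

    import org.junit.BeforeClass;

    public class SomeIntegrationTest extends BaseDL4JTest {
        @BeforeClass    //Static context - this is what the static helpers enable
        public static void checkProfile() {
            skipUnlessIntegrationTests();   //Skips the whole class unless DL4J_INTEGRATION_TESTS is set
        }
    }
]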
@@ -139,6 +142,15 @@ public abstract class BaseDL4JTest {
             //Not really safe to continue testing under this situation... other tests will likely fail with obscure
             // errors that are hard to track back to this
             log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
+            System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
+            System.out.flush();
+            //Try to flush logs also:
+            try{ Thread.sleep(1000); } catch (InterruptedException e){ }
+            ILoggerFactory lf = LoggerFactory.getILoggerFactory();
+            if( lf instanceof LoggerContext){
+                ((LoggerContext)lf).stop();
+            }
+            try{ Thread.sleep(1000); } catch (InterruptedException e){ }
             System.exit(1);
         }
@@ -164,6 +164,20 @@
         <artifactId>oshi-core</artifactId>
         <version>${oshi.version}</version>
     </dependency>

+    <!-- Test scope reflections to ensure all classes extend base test class -->
+    <dependency>
+        <groupId>org.reflections</groupId>
+        <artifactId>reflections</artifactId>
+        <version>${reflections.version}</version>
+        <scope>test</scope>
+        <exclusions>
+            <exclusion>
+                <groupId>com.google.code.findbugs</groupId>
+                <artifactId>*</artifactId>
+            </exclusion>
+        </exclusions>
+    </dependency>
 </dependencies>

 <profiles>
@@ -0,0 +1,72 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j;
+
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Test;
+import org.reflections.Reflections;
+import org.reflections.scanners.MethodAnnotationsScanner;
+import org.reflections.util.ClasspathHelper;
+import org.reflections.util.ConfigurationBuilder;
+
+import java.lang.reflect.Method;
+import java.util.*;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
+ * extends BaseDl4JTest - either directly or indirectly.
+ * Other than a small set of exceptions, all tests must extend this
+ *
+ * @author Alex Black
+ */
+@Slf4j
+public class AssertTestsExtendBaseClass extends BaseDL4JTest {
+
+    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
+    private static final Set<Class<?>> exclusions = new HashSet<>();
+
+    @Test
+    public void checkTestClasses(){
+
+        Reflections reflections = new Reflections(new ConfigurationBuilder()
+                .setUrls(ClasspathHelper.forPackage("org.deeplearning4j"))
+                .setScanners(new MethodAnnotationsScanner()));
+        Set<Method> methods = reflections.getMethodsAnnotatedWith(Test.class);
+        Set<Class<?>> s = new HashSet<>();
+        for(Method m : methods){
+            s.add(m.getDeclaringClass());
+        }
+
+        List<Class<?>> l = new ArrayList<>(s);
+        Collections.sort(l, new Comparator<Class<?>>() {
+            @Override
+            public int compare(Class<?> aClass, Class<?> t1) {
+                return aClass.getName().compareTo(t1.getName());
+            }
+        });
+
+        int count = 0;
+        for(Class<?> c : l){
+            if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
+                log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
+                count++;
+            }
+        }
+        assertEquals("Number of tests not extending BaseDL4JTest", 0, count);
+    }
+}
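[Reviewer note: the exclusions set ships empty. A class that legitimately cannot extend BaseDL4JTest (e.g., one that is only ever run manually) would be registered in a static initializer; a hypothetical sketch, with an invented class name:

    static {
        exclusions.add(SomeManuallyRunDebugTest.class);   //Hypothetical exempted class
    }
]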
@@ -17,7 +17,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 import java.util.concurrent.CountDownLatch;

 @Ignore
-public class RandomTests {
+public class RandomTests extends BaseDL4JTest {

     @Test
     public void testReproduce() throws Exception {
@@ -16,11 +16,12 @@

 package org.deeplearning4j.datasets;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
 import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
 import org.junit.Test;

-public class TestDataSets {
+public class TestDataSets extends BaseDL4JTest {

     @Test
     public void testTinyImageNetExists() throws Exception {
@@ -1006,9 +1006,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
         for (RecordMetaData m : meta) {
             Record r = csv.loadFromMetaData(m);
             INDArray row = ds.getFeatures().getRow(i);
-            if(i <= 3) {
-                System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
-            }
+//            if(i <= 3) {
+//                System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
+//            }

             for (int j = 0; j < 4; j++) {
                 double exp = r.getRecord().get(j).toDouble();

@@ -1017,7 +1017,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
             }
             i++;
         }
-        System.out.println();
+//        System.out.println();

         DataSet fromMeta = rrdsi.loadFromMetaData(meta);
         assertEquals(ds, fromMeta);
@@ -19,6 +19,7 @@ package org.deeplearning4j.datasets.iterator;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import lombok.var;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
 import org.junit.Test;
 import org.nd4j.linalg.dataset.api.DataSet;

@@ -31,7 +32,7 @@ import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;

 @Slf4j
-public class DummyBlockDataSetIteratorTests {
+public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {

     @Test
     public void testBlock_1() throws Exception {
@@ -18,13 +18,14 @@ package org.deeplearning4j.datasets.iterator;

 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
 import org.junit.Test;

 import static org.junit.Assert.*;

 @Slf4j
-public class JointMultiDataSetIteratorTests {
+public class JointMultiDataSetIteratorTests extends BaseDL4JTest {

     @Test (timeout = 20000L)
     public void testJMDSI_1() {
@@ -16,6 +16,7 @@

 package org.deeplearning4j.datasets.iterator;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
 import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
 import org.junit.Test;

@@ -37,7 +38,7 @@ import java.util.Random;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;

-public class LoaderIteratorTests {
+public class LoaderIteratorTests extends BaseDL4JTest {

     @Test
     public void testDSLoaderIter(){
@@ -17,6 +17,7 @@
 package org.deeplearning4j.nn.graph.graphnodes;

 import lombok.val;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

@@ -54,7 +55,7 @@ import java.util.Map;

 import static org.junit.Assert.*;

-public class TestGraphNodes {
+public class TestGraphNodes extends BaseDL4JTest {

     @Test
     public void testMergeNode() {
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.layers;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;

@@ -32,7 +33,7 @@ import java.util.Arrays;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;

-public class RepeatVectorTest {
+public class RepeatVectorTest extends BaseDL4JTest {

     private int REPEAT = 4;
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.layers.convolution;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ConvolutionMode;

@@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * @author Max Pumperla
  */
-public class Convolution3DTest {
+public class Convolution3DTest extends BaseDL4JTest {

     private int nExamples = 1;
     private int nChannelsOut = 1;
@@ -47,8 +47,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;

-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;

 public class EmbeddingLayerTest extends BaseDL4JTest {

@@ -725,4 +724,79 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         assertEquals(new ActivationIdentity(), l2.getActivationFn());

     }
+
+    @Test
+    public void testEmbeddingWeightInit(){
+        // https://github.com/eclipse/deeplearning4j/issues/8663
+        //The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
+
+        for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL}) {
+            for (boolean seq : new boolean[]{false, true}) {
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                net.init();
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                net2.init();
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
+                net3.init();
+
+                INDArray p1 = net.params();
+                INDArray p2 = net2.params();
+                INDArray p3 = net3.params();
+                boolean eq = p1.equalsWithEps(p2, 1e-4);
+                String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
+                assertTrue(str + " p1/p2 params not equal", eq);
+
+                double m1 = p1.meanNumber().doubleValue();
+                double s1 = p1.stdNumber().doubleValue();
+
+                double m3 = p3.meanNumber().doubleValue();
+                double s3 = p3.stdNumber().doubleValue();
+
+                assertEquals(str, m1, m3, 0.1);
+                assertEquals(str, s1, s3, 0.1);
+
+                double re = relErr(s1, s3);
+                assertTrue(str + " - " + re, re < 0.05);
+            }
+        }
+    }
+
+    public static double relErr(double d1, double d2){
+        if(d1 == 0.0 && d2 == 0.0)
+            return 0.0;
+        return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2));
+    }
 }
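[Reviewer note: the tolerance assertions above catch fan-in-dependent initialization, which scales with vocabulary size. For Xavier, std ≈ sqrt(2 / (fanIn + fanOut)); an arithmetic sketch of what goes wrong when nIn is treated as fan-in (exact formulas may differ slightly per WeightInit variant):

    double stdSmall = Math.sqrt(2.0 / (100 + 100));      // nIn=100    -> ≈ 0.100
    double stdLarge = Math.sqrt(2.0 / (100000 + 100));   // nIn=100000 -> ≈ 0.0045, ~22x smaller

Since only one row of an embedding table is active per example, the per-row initializer added in this PR (see the EmbeddingLayerParamInitializer import at the end of this diff) is presumably what keeps the distribution independent of nIn.]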
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.layers.ocnn;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.gradientcheck.GradientCheckUtil;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

@@ -51,7 +52,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;


-public class OCNNOutputLayerTest {
+public class OCNNOutputLayerTest extends BaseDL4JTest {

     private static final boolean PRINT_RESULTS = true;
     private static final boolean RETURN_ON_FIRST_FAILURE = false;
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.layers.recurrent;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.layers.GravesLSTM;

@@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;

 import static org.junit.Assert.assertTrue;

-public class TestRecurrentWeightInit {
+public class TestRecurrentWeightInit extends BaseDL4JTest {

     @Test
     public void testRWInit() {
@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;

@@ -136,6 +137,11 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest {
     }

     private class ValidatingSameDiffVertex extends SameDiffVertex {
+        @Override
+        public GraphVertex clone() {
+            return new ValidatingSameDiffVertex();
+        }
+
         @Override
         public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
             return vertexInputs[0];
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;

 import lombok.Data;
 import lombok.NoArgsConstructor;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;

@@ -74,4 +75,9 @@ public class SameDiffDenseVertex extends SameDiffVertex {
     public char paramReshapeOrder(String paramName){
         return 'f'; //To match DL4J DenseLayer - for easy comparison
     }
+
+    @Override
+    public GraphVertex clone() {
+        return new SameDiffDenseVertex(nIn, nOut, activation, weightInit);
+    }
 }
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.layers.samediff.testlayers;

+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.misc;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

@@ -35,7 +36,7 @@ import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;

 @Ignore //Ignored due to very large memory requirements
-public class LargeNetTest {
+public class LargeNetTest extends BaseDL4JTest {

     @Ignore
     @Test
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
 import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.graph.AttentionVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;

@@ -35,6 +36,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -44,6 +46,9 @@ import org.nd4j.linalg.learning.config.RmsProp;
 import org.nd4j.linalg.learning.config.Sgd;
 import org.nd4j.linalg.lossfunctions.LossFunctions;

+import java.util.HashMap;
+import java.util.Map;
+
 import static org.junit.Assert.*;

 /**

@@ -565,4 +570,99 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
         assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName));
         newGraph.output(input);
     }
+
+    @Test
+    public void testTransferLearningSameDiffLayersGraph(){
+
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
+                .layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
+                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
+        INDArray out = cg.output(arr)[0];
+
+        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeVertexAndConnections("out")
+                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("newOut")
+                .build();
+
+        cg2.output(arr);
+
+        Map<String,INDArray> m = new HashMap<>(cg.paramTable());
+        m.put("newOut_W", m.remove("out_W"));
+        m.put("newOut_b", m.remove("out_b"));
+        cg2.setParamTable(m);
+
+        Map<String,INDArray> p1 = cg.paramTable();
+        Map<String,INDArray> p2 = cg2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = cg2.outputSingle(arr);
+        assertEquals(out, out2);
+    }
+
+    @Test
+    public void testTransferLearningSameDiffLayersGraphVertex(){
+
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
+                .addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0")
+                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
+        INDArray out = cg.output(arr)[0];
+
+        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeVertexAndConnections("out")
+                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("newOut")
+                .build();
+
+        cg2.output(arr);
+
+        Map<String,INDArray> m = new HashMap<>(cg.paramTable());
+        m.put("newOut_W", m.remove("out_W"));
+        m.put("newOut_b", m.remove("out_b"));
+        cg2.setParamTable(m);
+
+        Map<String,INDArray> p1 = cg.paramTable();
+        Map<String,INDArray> p2 = cg2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = cg2.outputSingle(arr);
+        assertEquals(out, out2);
+    }
 }
@@ -41,6 +41,7 @@ import org.deeplearning4j.nn.weights.WeightInitRelu;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -48,6 +49,8 @@ import org.nd4j.linalg.learning.config.*;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.shade.jackson.core.JsonProcessingException;

+import java.util.Map;
+
 import static org.junit.Assert.*;

 /**

@@ -689,4 +692,51 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         assertEquals("Incorrect number of inputs!", 5, newNet.layerInputSize(2));
         newNet.output(input);
     }
+
+    @Test
+    public void testTransferLearningSameDiffLayers(){
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
+                .activation(Activation.TANH)
+                .updater(new Adam(0.01))
+                .weightInit(WeightInit.XAVIER)
+                .list()
+                .layer(new LSTM.Builder().nOut(8).build())
+                .layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
+                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
+                .layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .setInputType(InputType.recurrent(4))
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4, 5);
+        INDArray out = net.output(in);
+
+        MultiLayerNetwork net2 = new TransferLearning.Builder(net)
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeLayersFromOutput(1)
+                .addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .build();
+
+        net2.setParam("3_W", net.getParam("3_W"));
+        net2.setParam("3_b", net.getParam("3_b"));
+
+        Map<String,INDArray> p1 = net.paramTable();
+        Map<String,INDArray> p2 = net2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s);
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = net2.output(in);
+
+        assertEquals(out, out2);
+    }
 }
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.updater.custom;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.BaseLayer;

@@ -35,7 +36,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Created by Alex on 09/05/2017.
  */
-public class TestCustomUpdater {
+public class TestCustomUpdater extends BaseDL4JTest {

     @Test
     public void testCustomUpdater() {
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.weights;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.distribution.*;
 import org.deeplearning4j.nn.conf.serde.JsonMappers;
 import org.junit.After;

@@ -40,7 +41,7 @@ import static org.junit.Assert.*;
 *
 * @author Christian Skarby
 */
-public class LegacyWeightInitTest {
+public class LegacyWeightInitTest extends BaseDL4JTest {

     private RandomFactory prevFactory;
     private final static int SEED = 666;
@@ -16,6 +16,7 @@

 package org.deeplearning4j.nn.weights;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;

@@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
 *
 * @author Christian Skarby
 */
-public class WeightInitIdentityTest {
+public class WeightInitIdentityTest extends BaseDL4JTest {

     /**
      * Test identity mapping for 1d convolution
@@ -1,5 +1,6 @@
 package org.deeplearning4j.optimizer.listener;

+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
 import org.junit.Ignore;
 import org.junit.Test;

@@ -7,7 +8,7 @@ import org.junit.Test;
 import java.util.List;
 import static org.junit.Assert.*;

-public class ScoreStatTest {
+public class ScoreStatTest extends BaseDL4JTest {
     @Test
     public void testScoreStatSmall() {
         CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
@@ -17,6 +17,7 @@
 package org.deeplearning4j.regressiontest;

 import org.apache.commons.io.FileUtils;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

@@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;

-public class MiscRegressionTests {
+public class MiscRegressionTests extends BaseDL4JTest {

     @Test
     public void testFrozen() throws Exception {
@@ -28,6 +28,7 @@ import java.util.*;
 @lombok.Builder
 public class FastText implements WordVectors, Serializable {

+    private final static String METHOD_NOT_AVAILABLE = "This method is available for text (.vec) models only - binary (.bin) model currently loaded";
     // Mandatory
     @Getter private String inputFile;
     @Getter private String outputFile;

@@ -219,6 +220,7 @@ public class FastText implements WordVectors, Serializable {

     public void loadBinaryModel(String modelPath) {
         fastTextImpl.loadModel(modelPath);
+
         modelLoaded = true;
     }

@@ -368,14 +370,12 @@ public class FastText implements WordVectors, Serializable {
         return words.contains(word);
     }

-    protected transient ModelUtils modelUtils;
-
     @Override
     public Collection<String> wordsNearest(INDArray words, int top) {
         if (modelVectorsLoaded) {
             return word2Vec.wordsNearest(words, top);
         }
-        return modelUtils.wordsNearest(words, top);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

     @Override

@@ -383,7 +383,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.wordsNearestSum(words, top);
         }
-        return modelUtils.wordsNearestSum(words, top);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

     @Override

@@ -391,7 +391,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.wordsNearestSum(word, n);
         }
-        return modelUtils.wordsNearestSum(word, n);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

@@ -400,7 +400,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.wordsNearestSum(positive, negative, top);
         }
-        return modelUtils.wordsNearestSum(positive, negative, top);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

     @Override

@@ -408,7 +408,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.accuracy(questions);
         }
-        return modelUtils.accuracy(questions);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

     @Override

@@ -425,7 +425,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.similarWordsInVocabTo(word, accuracy);
         }
-        return modelUtils.similarWordsInVocabTo(word, accuracy);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

     @Override

@@ -433,7 +433,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.wordsNearest(positive, negative, top);
         }
-        return modelUtils.wordsNearest(positive, negative, top);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

@@ -442,7 +442,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.wordsNearest(word,n);
         }
-        return modelUtils.wordsNearestSum(word, n);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

@@ -451,7 +451,7 @@ public class FastText implements WordVectors, Serializable {
         if (modelVectorsLoaded) {
             return word2Vec.similarity(word, word2);
         }
-        return modelUtils.similarity(word, word2);
+        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
     }

     @Override

@@ -464,7 +464,6 @@ public class FastText implements WordVectors, Serializable {

     @Override
     public void setModelUtils(ModelUtils utils) {
-        this.modelUtils = utils;
     }

     @Override
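[Reviewer note: these changes remove the modelUtils field entirely, so similarity/nearest-word queries on a binary (.bin) model now fail fast with a descriptive message instead of risking a null-pointer path. Note also that the old wordsNearest(word, n) mistakenly delegated to wordsNearestSum; the rewrite removes that bug along with the field. A hedged caller-side sketch (hypothetical path; the no-arg constructor is an assumption based on the @lombok.Builder annotation):

    FastText ft = new FastText();                  //Assumed no-arg constructor
    ft.loadBinaryModel("/path/to/model.bin");      //Hypothetical path to a .bin model
    try {
        ft.wordsNearest("soccer", 5);
    } catch (IllegalStateException e) {
        //"This method is available for text (.vec) models only - binary (.bin) model currently loaded"
    }
]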
@@ -121,6 +121,21 @@ public class FastTextTest extends BaseDL4JTest {
         assertEquals("__label__soccer", label);
     }

+    @Test(expected = IllegalStateException.class)
+    public void testIllegalState() {
+        String text = "I like soccer";
+
+        FastText fastText = new FastText(supModelFile);
+        assertEquals(48, fastText.vocab().numWords());
+        assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
+
+        double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
+        assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
+
+        String label = fastText.predict(text);
+        fastText.wordsNearest("test",1);
+    }
+
     @Test
     public void testPredictProbability() {
         String text = "I like soccer";
@@ -427,7 +427,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
         if(!disconnected.isEmpty() && !allowNoOutput){ //If allowing no output: by definition we have disconnected vertices
             throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
                     + ". Disconnected vertices are those that do not connect to either another vertex, and are also"
-                    + " not a network output. To disable this error (i.e., allow network configurations with" +
+                    + " not a network output. This vertex can be set as an output using setOutputs(String...). "
+                    + "To disable this error (i.e., allow network configurations with" +
                     " disconnected vertices) use GraphBuilder.allowDisconnected(true)");
         }
     }
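To illustrate the expanded error message: a minimal sketch of a configuration that triggers it, assuming the standard DL4J graph-builder API (the names "dense", "orphan" and the layer sizes are arbitrary). Adding the vertex to setOutputs(...) or calling allowDisconnected(true) resolves the error:

ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("in")
        .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
        .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).build(), "dense")
        .addVertex("orphan", new ScaleVertex(2.0), "in")   //Feeds no other vertex, and is not an output
        .setOutputs("out")                                 //Fix 1: setOutputs("out", "orphan")
        //.allowDisconnected(true)                         //Fix 2: explicitly allow disconnected vertices
        .build();                                          //Throws: disconnected vertices found - [orphan]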
@@ -72,6 +72,20 @@ public class AttentionVertex extends SameDiffVertex {
         this.weightInit = builder.weightInit;
     }
 
+    @Override
+    public AttentionVertex clone() {
+        AttentionVertex av = new AttentionVertex();
+        av.nInKeys = nInKeys;
+        av.nInValues = nInValues;
+        av.nInQueries = nInQueries;
+        av.nOut = nOut;
+        av.headSize = headSize;
+        av.nHeads = nHeads;
+        av.projectInput = projectInput;
+        av.weightInit = weightInit;
+        return av;
+    }
+
     @Override
     public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
         InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -79,7 +80,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
 
     @Override
     public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
     }
 
     @Override
@@ -24,7 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -92,7 +92,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
 
     @Override
     public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
     }
 
     @Override
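Both embedding layer types now route weight initialization through EmbeddingLayerParamInitializer (added later in this commit). A minimal configuration sketch, assuming the usual builder API; the embedding weights below are now initialized with fanIn = 1 rather than fanIn = vocabulary size:

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
        .list()
        .layer(new EmbeddingLayer.Builder().nIn(10_000).nOut(128).build())  //10k-word vocabulary
        .layer(new OutputLayer.Builder().nIn(128).nOut(2)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).build())
        .build();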
@@ -16,11 +16,13 @@
 
 package org.deeplearning4j.nn.conf.layers.samediff;
 
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
+import java.lang.reflect.InvocationTargetException;
 import java.util.*;
 
 
@@ -75,6 +77,15 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
         //No op, for lambda vertex
     }
 
+    @Override
+    public GraphVertex clone() {
+        try {
+            return getClass().getConstructor().newInstance();
+        } catch (Exception e){
+            throw new RuntimeException("Unable to create new instance of class " + getClass().getName() + " from no-arg constructor");
+        }
+    }
+
     protected VertexInputs getInputs(SameDiff sd) {
         if (inputs == null) {
             inputs = new VertexInputs(sd);
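The reflective clone() above assumes every concrete lambda vertex is stateless and exposes a public no-arg constructor. A minimal sketch of a subclass compatible with that contract (TimesTwoVertex is a hypothetical name; the op itself is arbitrary):

public class TimesTwoVertex extends SameDiffLambdaVertex {

    public TimesTwoVertex() {
        //Required by clone(): instances are re-created via getClass().getConstructor()
    }

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
        //Stateless element-wise transform of the single input
        return inputs.getInput(0).mul(2.0);
    }
}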
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
@@ -36,6 +37,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 
+import java.lang.reflect.Field;
 import java.util.List;
 import java.util.Map;
 
@@ -99,11 +101,6 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
         return vertexParams;
     }
 
-    @Override
-    public GraphVertex clone() {
-        throw new UnsupportedOperationException("Not yet implemented");
-    }
-
     @Override
     public long numParams(boolean backprop) {
         SDLayerParams params = getVertexParams();
@@ -3394,7 +3394,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
 
     @Override
     public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
-        Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
+        Map<String,INDArray> m = paramTable();
+        Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal");
         Map<String,INDArray> current = paramTable();
         //Check shapes before doing partial assigment to avoid leaving net in incorrect state
         for(String s : current.keySet()){
@@ -237,9 +237,16 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 
     @Override
     public void setParams(INDArray params) {
-        if (params != null) {
-            throw new UnsupportedOperationException("Not supported");
-        }
+        if(this.params == null && params == null)
+            return;
+        if(this.params == null)
+            throw new IllegalStateException("Cannot set parameters of length " + params.length() + " to a layer with no parameters");
+        if(params == null)
+            throw new IllegalStateException("Cannot set null parameters");
+
+        Preconditions.checkState(this.params.length() == params.length(), "Cannot assign parameter vector of length %s to a layer with %s parameters",
+                params.length(), this.params.length());
+        this.params.assign(params);
     }
 
     protected void setParams(INDArray params, char order) {
@@ -0,0 +1,52 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.params;
+
+import lombok.val;
+import org.deeplearning4j.nn.weights.IWeightInit;
+import org.deeplearning4j.nn.weights.WeightInitUtil;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Parameter initializer for EmbeddingLayer and EmbeddingSequenceLayer
+ *
+ * @author Alex Black
+ */
+public class EmbeddingLayerParamInitializer extends DefaultParamInitializer {
+
+    private static final EmbeddingLayerParamInitializer INSTANCE = new EmbeddingLayerParamInitializer();
+
+    public static EmbeddingLayerParamInitializer getInstance() {
+        return INSTANCE;
+    }
+
+
+    protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit,
+                                          INDArray weightParamView, boolean initializeParameters) {
+        val shape = new long[] {nIn, nOut};
+
+        if (initializeParameters) {
+            INDArray ret = weightInit.init(1, //Fan in - note that fanIn=1 for embedding layer... if we used layer nIn (i.e., vocab size) the init would depend on vocab size (which doesn't make sense)
+                    nOut, //Fan out
+                    shape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
+            return ret;
+        } else {
+            return WeightInitUtil.reshapeWeights(shape, weightParamView);
+        }
+    }
+
+}
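A worked illustration of the fanIn comment in createWeightMatrix above (a standalone sketch, not part of the commit): for Xavier-style inits the scale goes roughly as sqrt(2/(fanIn+fanOut)), so using the vocabulary size as fanIn would shrink the embedding weights as the vocabulary grows, while fanIn = 1 keeps the scale vocabulary-independent:

public class EmbeddingFanInDemo {
    public static void main(String[] args) {
        long nOut = 128;   //Embedding size
        for (long vocab : new long[]{1_000, 100_000}) {
            double vocabFanInStd = Math.sqrt(2.0 / (vocab + nOut)); //Scale if fanIn were the vocab size
            double unitFanInStd = Math.sqrt(2.0 / (1 + nOut));      //Scale with fanIn = 1, as above
            System.out.printf("vocab=%d: fanIn=vocab std=%.5f, fanIn=1 std=%.5f%n",
                    vocab, vocabFanInStd, unitFanInStd);
        }
    }
}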
@@ -19,6 +19,7 @@ package org.deeplearning4j.optimize.listeners;
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
 import it.unimi.dsi.fastutil.ints.IntArrayList;
 import lombok.Data;
+import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.optimize.api.BaseTrainingListener;
@@ -32,6 +33,7 @@ import java.io.Serializable;
  * @author Alex Black
  */
 @Data
+@EqualsAndHashCode(callSuper = true)
 @Slf4j
 public class CollectScoresListener extends BaseTrainingListener implements Serializable {
 
@@ -1,16 +1,15 @@
 
-#DL4J Integration Tests
+#DL4J and SameDiff Integration Tests
 
-These tests are designed to check a number of aspects of DL4J:
-1. Predictions
+These tests are designed to check a number of aspects of DL4J and SameDiff:
+1. Predictions (i.e., network output)
 2. Training (training curves, parameters, gradient calculation)
-3. Evaluation
-4. Model serialization
-5. Overfitting sanity checks
+3. Evaluation (accuracy, etc)
+4. Model serialization (saving + loading models)
+5. Overfitting sanity checks (make sure we can overfit a single example)
 6. Data pipelines
-7. Evaluation classes
-8. Parallel Wrapper
-9. Validating conditions that should always hold (frozen layer params don't change, for example)
+7. Parallel Wrapper
+8. Validating conditions that should always hold (frozen layer params don't change, for example)
 
 
 They are designed for the following purposes:
@@ -19,32 +18,46 @@ They are designed for the following purposes:
 3. Detecting significant differences between CPU and CUDA backends
 4. Validating implementation via sanity checks on training - i.e., can we overfit a single example?
 5. Checking networks and data pipelines on real-world scale data and nets
-6. Operating as fully automated pre-release checks (replacing previously used manual checks)
+6. Operating as fully automated pre-release checks (replacing manual sanity checks)
 
-## Types of Tests
+## Main Classes
 
-The integration tests are set up to be able to run multiple tests on each network configuration.
+Explanation of the main classes:
+* **IntegrationTestBaselineGenerator**: Run *manually* to generate and save "expected results" for comparing in the future.
+  Output goes to dl4j-test-resources, for saving/uploading.
+* **IntegrationTestRunner**: Actually runs the tests, and compares the output/result to those generated by the baseline generator
+* **TestCase**: integration tests extend this
+* **testcases/\*.java**: the actual integration test definitions
+* **IntegrationTestsDL4J**: entry point for running the DL4J integration tests
+* **IntegrationTestsSameDiff**: entry point for running the SameDiff integration tests
+
+## Types of Test Components
+
+The integration tests are set up to be able to run multiple types of tests on each network configuration.
 
 Networks may be pretrained (from model zoo) or randomly initialized (from specified configuration).
 
 Specifically, test cases can be run with any subset of the following components to be tested, by setting TestCase.XYZ boolean options to true or false:
 
-1. testPredictions: Testing output (predictions) on some specified data vs. saved/known good arrays
-2. testGradients: Testing gradients on some specified data vs. saved/known good arrays
-3. testPretrain: Test layerwise pretraining parameters and training curves
-4. testTrainingCurves: Train, and check score vs. iteration
-5. testParamsPostTraining: validate params match post training
-6. testEvaluation: test the evaluation performance (post training, if 4 or 5 are true)
-7. testParallelInference: validate that single net and parallel inference results match
-8. testOverfitting: sanity check - try to overfit a single example
+1. **testPredictions**: Testing output (predictions) on some specified data vs. saved/known good arrays
+2. **testGradients**: Testing gradients on some specified data vs. saved/known good arrays
+3. **testPretrain**: Test layerwise pretraining parameters and training curves
+4. **testTrainingCurves**: Train, and check score vs. iteration
+5. **testParamsPostTraining**: validate params match post training
+6. **testEvaluation**: test the evaluation performance (post training, if 4 or 5 are true)
+7. **testParallelInference**: validate that single net and parallel inference results match
+8. **testOverfitting**: sanity check - try to overfit a single example
+
+See TestCase.java for more details.
 
 
 ## Adding a New Integration Test
 
 The process to add a new test is simple:
-1. Add a method that creates and returns a TestCase object
-2. Add it as a unit test to IntegrationTests class
-3. Run IntegrationTestBaselineGenerator (if required) to generate and save the "known good" results.
+1. Add a method that creates and returns a TestCase object (example: testcases/MLPTestCases.getMLPMnist())
+2. Add it as a unit test to IntegrationTests class (example: IntegrationTestsDL4J.testMLPMnist())
+3. Run IntegrationTestBaselineGenerator with the new test case, to generate and save the "known good" results.
+4. Run the new integration test to make sure it passes, on both CPU and CUDA backends
+5. Commit the generated test resources from step 3 to dl4j-test-resources repo
 
 Note that IntegrationTestBaselineGenerator assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.
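A hedged skeleton of steps 1-2 from the README above. The names MyTestCases, getMyMLPTest and testMyMLP are illustrative only; the field assignments follow the TestCase.XYZ boolean options the README describes:

//Step 1: a factory method that creates and returns a TestCase (illustrative skeleton)
public class MyTestCases {
    public static TestCase getMyMLPTest() {
        return new TestCase() {
            {
                testName = "MyMLPTest";
                testType = TestType.RANDOM_INIT;
                testPredictions = true;
                testTrainingCurves = true;
                testParamsPostTraining = true;
            }

            @Override
            public ModelType modelType() {
                return ModelType.MLN;
            }

            @Override
            public Object getConfiguration() {
                //Return a MultiLayerConfiguration here
                throw new UnsupportedOperationException("sketch only");
            }
        };
    }
}

//Step 2: a JUnit entry point alongside the existing tests (testDir is the
//class's @Rule TemporaryFolder, as in IntegrationTestsDL4J)
@Test
public void testMyMLP() throws Exception {
    IntegrationTestRunner.runTest(MyTestCases.getMyMLPTest(), testDir);
}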
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,15 +17,10 @@
 
 package org.deeplearning4j.integration;
 
-import org.nd4j.shade.guava.io.Files;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
-import org.deeplearning4j.eval.IEvaluation;
-import org.deeplearning4j.integration.testcases.CNN2DTestCases;
-import org.deeplearning4j.integration.testcases.MLPTestCases;
-import org.deeplearning4j.integration.testcases.RNNTestCases;
-import org.deeplearning4j.integration.testcases.UnsupervisedTestCases;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -32,20 +28,27 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.optimize.listeners.CollectScoresListener;
 import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.autodiff.listeners.records.History;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.VariableType;
 import org.nd4j.base.Preconditions;
-import org.nd4j.linalg.api.buffer.DataBuffer;
-import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.shade.guava.io.Files;
 
 import java.io.*;
+import java.nio.charset.StandardCharsets;
 import java.util.*;
 import java.util.stream.Collectors;
 
+import static org.junit.Assert.assertEquals;
 
 /**
  * Run this manually to generate - or update - the saved files for a specific test.
  * Places results in dl4j-test-resources: assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.
@@ -53,32 +56,31 @@ import java.util.stream.Collectors;
 @Slf4j
 public class IntegrationTestBaselineGenerator {
 
-    public static final File OUTPUT_DIR = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
+    public static final File OUTPUT_DIR_DL4J = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
+    public static final File OUTPUT_DIR_SAMEDIFF = new File("../../dl4j-test-resources/src/main/resources/samediff-integration-tests").getAbsoluteFile();
 
 
     public static void main(String[] args) throws Exception {
-        if (!OUTPUT_DIR.exists()) {
-            throw new RuntimeException("output directory (test resources) does not exist!");
+        if (!OUTPUT_DIR_DL4J.exists() && !OUTPUT_DIR_SAMEDIFF.exists()) {
+            throw new RuntimeException("output directories in test resources do not exist!");
         }
 
-        //All integration tests are run with float precision!
-        Nd4j.setDataType(DataType.FLOAT);
-//        runGeneration(
-//                MLPTestCases.getMLPMnist(),
-//        );
+        runGeneration(
+                SameDiffMLPTestCases.getMLPMnist()
+        );
 
     }
 
     private static void runGeneration(TestCase... testCases) throws Exception {
 
         for( TestCase tc : testCases ) {
+            final ModelType modelType = tc.modelType();
 
             //Basic validation:
             Preconditions.checkState(tc.getTestName() != null, "Test case name is null");
 
             //Run through each test case:
-            File testBaseDir = new File(OUTPUT_DIR, tc.getTestName());
+            File testBaseDir = new File(modelType == ModelType.SAMEDIFF ? OUTPUT_DIR_SAMEDIFF : OUTPUT_DIR_DL4J, tc.getTestName());
             if (testBaseDir.exists()) {
                 FileUtils.forceDelete(testBaseDir);
             }
@@ -109,56 +111,62 @@ public class IntegrationTestBaselineGenerator {
             //First: if test is a random init test: generate the config, and save it
             MultiLayerNetwork mln = null;
             ComputationGraph cg = null;
-            Model m;
-            boolean isMLN;
+            SameDiff sd = null;
+            Model m = null;
             if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
                 Object config = tc.getConfiguration();
-                String json;
+                String json = null;
                 if (config instanceof MultiLayerConfiguration) {
                     MultiLayerConfiguration mlc = (MultiLayerConfiguration) config;
-                    isMLN = true;
                     json = mlc.toJson();
                     mln = new MultiLayerNetwork(mlc);
                     mln.init();
                     m = mln;
-                } else {
+                } else if (config instanceof ComputationGraphConfiguration){
                     ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
-                    isMLN = false;
                     json = cgc.toJson();
                     cg = new ComputationGraph(cgc);
                     cg.init();
                     m = cg;
+                } else {
+                    sd = (SameDiff)config;
                 }
 
-                File configFile = new File(testBaseDir, "config." + (isMLN ? "mlc.json" : "cgc.json"));
-                FileUtils.writeStringToFile(configFile, json);
-                log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
                 File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
+                if(modelType != ModelType.SAMEDIFF) {
+                    File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
+                    FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
+                    log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
                     ModelSerializer.writeModel(m, savedModel, true);
+                } else {
+                    sd.save(savedModel, true);
+                }
                 log.info("RANDOM_INIT test - saved randomly initialized model to: {}", savedModel.getAbsolutePath());
             } else {
                 //Pretrained model
                 m = tc.getPretrainedModel();
-                isMLN = (m instanceof MultiLayerNetwork);
-                if (isMLN) {
+                if (m instanceof MultiLayerNetwork) {
                     mln = (MultiLayerNetwork) m;
-                } else {
+                } else if(m instanceof ComputationGraph){
                     cg = (ComputationGraph) m;
+                } else {
+                    sd = (SameDiff)m;
                 }
             }
 
 
             //Generate predictions to compare against
             if (tc.isTestPredictions()) {
-                List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
-                Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
+                List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
+                List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
+//                Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
 
 
                 File predictionsTestDir = new File(testBaseDir, "predictions");
                 predictionsTestDir.mkdirs();
 
                 int count = 0;
-                if (isMLN) {
+                if (modelType == ModelType.MLN) {
                     for (Pair<INDArray[], INDArray[]> p : inputs) {
                         INDArray f = p.getFirst()[0];
                         INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
@@ -170,7 +178,7 @@ public class IntegrationTestBaselineGenerator {
                             Nd4j.write(out, dos);
                         }
                     }
-                } else {
+                } else if(modelType == ModelType.CG) {
                     for (Pair<INDArray[], INDArray[]> p : inputs) {
                         INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
 
@@ -182,6 +190,19 @@ public class IntegrationTestBaselineGenerator {
                             }
                         }
                     }
+                } else {
+                    List<String> outNames = tc.getPredictionsNamesSameDiff();
+                    for( Map<String,INDArray> ph : inputsSd ){
+                        Map<String,INDArray> out = sd.output(ph, outNames);
+
+                        //Save the output...
+                        for(String s : outNames){
+                            File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
+                            try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
+                                Nd4j.write(out.get(s), dos);
+                            }
+                        }
+                    }
                 }
 
                 log.info("Saved predictions for {} inputs to disk in directory: {}", tc.getTestName(), predictionsTestDir);
@@ -189,32 +210,46 @@ public class IntegrationTestBaselineGenerator {
 
             //Compute and save gradients:
             if (tc.isTestGradients()) {
+                INDArray gradientFlat = null;
+                Map<String,INDArray> grad;
+                if (modelType == ModelType.MLN) {
                     MultiDataSet data = tc.getGradientsTestData();
-                INDArray gradientFlat;
-                if (isMLN) {
                     mln.setInput(data.getFeatures(0));
                     mln.setLabels(data.getLabels(0));
                     mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
                     mln.computeGradientAndScore();
                     gradientFlat = mln.getFlattenedGradients();
-                } else {
+                    grad = m.gradient().gradientForVariable();
+                } else if(modelType == ModelType.CG) {
+                    MultiDataSet data = tc.getGradientsTestData();
                     cg.setInputs(data.getFeatures());
                     cg.setLabels(data.getLabels());
                     cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
                     cg.computeGradientAndScore();
                     gradientFlat = cg.getFlattenedGradients();
+                    grad = m.gradient().gradientForVariable();
+                } else {
+                    Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
+                    List<String> allVars = new ArrayList<>();
+                    for(SDVariable v : sd.variables()){
+                        if(v.getVariableType() == VariableType.VARIABLE){
+                            allVars.add(v.name());
+                        }
+                    }
+                    grad = sd.calculateGradients(ph, allVars);
                 }
 
+                if(modelType != ModelType.SAMEDIFF) {
                     File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
                     IntegrationTestRunner.write(gradientFlat, gFlatFile);
+                }
 
                 //Also save the gradient param table:
-                Map<String, INDArray> g = m.gradient().gradientForVariable();
                 File gradientDir = new File(testBaseDir, "gradients");
                 gradientDir.mkdir();
-                for (String s : g.keySet()) {
+                for (String s : grad.keySet()) {
                     File f = new File(gradientDir, s + ".bin");
-                    IntegrationTestRunner.write(g.get(s), f);
+                    IntegrationTestRunner.write(grad.get(s), f);
                 }
             }
 
@@ -224,7 +259,7 @@ public class IntegrationTestBaselineGenerator {
                 MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
 
                 INDArray paramsPostTraining;
-                if(isMLN){
+                if(modelType == ModelType.MLN){
                     int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
                     Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
                     DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
@@ -233,7 +268,7 @@ public class IntegrationTestBaselineGenerator {
                         mln.pretrainLayer(i, dsi);
                     }
                     paramsPostTraining = mln.params();
-                } else {
+                } else if(modelType == ModelType.CG) {
                     String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                     Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
 
@@ -241,6 +276,8 @@ public class IntegrationTestBaselineGenerator {
                         cg.pretrainLayer(i, iter);
                     }
                     paramsPostTraining = cg.params();
+                } else {
+                    throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
                 }
 
                 //Save params
@@ -251,36 +288,61 @@ public class IntegrationTestBaselineGenerator {
             //Test training curves:
             if (tc.isTestTrainingCurves()) {
                 MultiDataSetIterator trainData = tc.getTrainingData();
 
                 CollectScoresListener l = new CollectScoresListener(1);
+                if(modelType != ModelType.SAMEDIFF)
                     m.setListeners(l);
 
-                if (isMLN) {
+                History h = null;
+                if (modelType == ModelType.MLN) {
                     mln.fit(trainData);
-                } else {
+                } else if(modelType == ModelType.CG) {
                     cg.fit(trainData);
+                } else {
+                    h = sd.fit(trainData, 1);
+                }
+
+                double[] scores;
+                if(modelType != ModelType.SAMEDIFF){
+                    scores = l.getListScore().toDoubleArray();
+                } else {
+                    scores = h.lossCurve().getLossValues().toDoubleVector();
                 }
 
-                double[] scores = l.getListScore().toDoubleArray();
                 File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
                 List<String> s = Arrays.stream(scores).mapToObj(String::valueOf).collect(Collectors.toList());
-                FileUtils.writeStringToFile(f, String.join(",", s));
+                FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);
 
                 if (tc.isTestParamsPostTraining()) {
+                    if(modelType == ModelType.SAMEDIFF){
+                        File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
+                        p.mkdirs();
+                        for(SDVariable v : sd.variables()){
+                            if(v.getVariableType() == VariableType.VARIABLE){
+                                INDArray arr = v.getArr();
+                                File p2 = new File(p, v.name() + ".bin");
+                                IntegrationTestRunner.write(arr, p2);
+                            }
+                        }
+                    } else {
                         File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
                         IntegrationTestRunner.write(m.params(), p);
                     }
                 }
+            }
 
 
             if (tc.isTestEvaluation()) {
                 IEvaluation[] evals = tc.getNewEvaluations();
                 MultiDataSetIterator iter = tc.getEvaluationTestData();
 
-                if (isMLN) {
+                if (modelType == ModelType.MLN) {
                     DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
                    mln.doEvaluation(dsi, evals);
-                } else {
+                } else if(modelType == ModelType.CG){
                     cg.doEvaluation(iter, evals);
+                } else {
+                    evals = tc.doEvaluationSameDiff(sd, iter, evals);
                 }
 
                 File evalDir = new File(testBaseDir, "evaluation");
@@ -288,7 +350,7 @@ public class IntegrationTestBaselineGenerator {
                 for (int i = 0; i < evals.length; i++) {
                     String json = evals[i].toJson();
                     File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
-                    FileUtils.writeStringToFile(f, json);
+                    FileUtils.writeStringToFile(f, json, StandardCharsets.UTF_8);
                 }
             }
 
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -17,14 +18,12 @@
|
||||||
package org.deeplearning4j.integration;
|
package org.deeplearning4j.integration;
|
||||||
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
|
||||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
|
||||||
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.eval.*;
|
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
||||||
import org.deeplearning4j.nn.api.Model;
|
import org.deeplearning4j.nn.api.Model;
|
||||||
import org.deeplearning4j.nn.conf.BackpropType;
|
import org.deeplearning4j.nn.conf.BackpropType;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
|
@ -42,9 +41,16 @@ import org.deeplearning4j.parallelism.ParallelInference;
|
||||||
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.autodiff.listeners.records.History;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.autodiff.samediff.VariableType;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
import org.nd4j.evaluation.classification.*;
|
||||||
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.Op;
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
|
@ -55,12 +61,15 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
import org.nd4j.resources.Resources;
|
||||||
|
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||||
|
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.Modifier;
|
import java.lang.reflect.Modifier;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
@ -79,6 +88,7 @@ public class IntegrationTestRunner {
|
||||||
public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin";
|
public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin";
|
||||||
public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv";
|
public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv";
|
||||||
public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin";
|
public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin";
|
||||||
|
public static final String PARAMS_POST_TRAIN_SAMEDIFF_DIR = "paramsPostTrain";
|
||||||
public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin";
|
public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin";
|
||||||
|
|
||||||
public static final double MAX_REL_ERROR_SCORES = 1e-4;
|
public static final double MAX_REL_ERROR_SCORES = 1e-4;
|
||||||
|
@ -148,21 +158,25 @@ public class IntegrationTestRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception {
|
public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception {
|
||||||
Preconditions.checkState(Nd4j.dataType() == DataType.FLOAT, "Integration tests must be run with float precision!");
|
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
||||||
log.info("Starting test case: {}", tc.getTestName());
|
//This could alternatively be done via maven surefire configuration
|
||||||
|
|
||||||
|
final ModelType modelType = tc.modelType();
|
||||||
|
log.info("Starting test case: {} - type = {}", tc.getTestName(), modelType);
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
|
|
||||||
File workingDir = testDir.newFolder();
|
File workingDir = testDir.newFolder();
|
||||||
tc.initialize(workingDir);
|
tc.initialize(workingDir);
|
||||||
|
|
||||||
File testBaseDir = testDir.newFolder();
|
File testBaseDir = testDir.newFolder();
|
||||||
new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
|
// new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
|
||||||
|
Resources.copyDirectory((modelType == ModelType.SAMEDIFF ? "samediff-integration-tests/" : "dl4j-integration-tests/") + tc.getTestName(), testBaseDir);
|
||||||
|
|
||||||
|
|
||||||
MultiLayerNetwork mln = null;
|
MultiLayerNetwork mln = null;
|
||||||
ComputationGraph cg = null;
|
ComputationGraph cg = null;
|
||||||
Model m;
|
SameDiff sd = null;
|
||||||
boolean isMLN;
|
Model m = null;
|
||||||
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
|
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
|
||||||
log.info("Checking RANDOM_INIT test case: saved model vs. initialized model");
|
log.info("Checking RANDOM_INIT test case: saved model vs. initialized model");
|
||||||
//Checking randomly initialized model:
|
//Checking randomly initialized model:
|
||||||
|
@ -173,36 +187,46 @@ public class IntegrationTestRunner {
|
||||||
mln = new MultiLayerNetwork(mlc);
|
mln = new MultiLayerNetwork(mlc);
|
||||||
mln.init();
|
mln.init();
|
||||||
m = mln;
|
m = mln;
|
||||||
isMLN = true;
|
|
||||||
|
|
||||||
MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
|
MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
|
||||||
assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations());
|
assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations());
|
||||||
assertEquals("Params not equal", loaded.params(), mln.params());
|
assertEquals("Params not equal", loaded.params(), mln.params());
|
||||||
assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable());
|
assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable());
|
||||||
} else {
|
} else if(config instanceof ComputationGraphConfiguration ){
|
||||||
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
|
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
|
||||||
cg = new ComputationGraph(cgc);
|
cg = new ComputationGraph(cgc);
|
||||||
cg.init();
|
cg.init();
|
||||||
m = cg;
|
m = cg;
|
||||||
isMLN = false;
|
|
||||||
|
|
||||||
ComputationGraph loaded = ComputationGraph.load(savedModel, true);
|
ComputationGraph loaded = ComputationGraph.load(savedModel, true);
|
||||||
assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration());
|
assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration());
|
||||||
assertEquals("Params not equal", loaded.params(), cg.params());
|
assertEquals("Params not equal", loaded.params(), cg.params());
|
||||||
assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable());
|
assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable());
|
||||||
|
} else if(config instanceof SameDiff){
|
||||||
|
sd = (SameDiff)config;
|
||||||
|
SameDiff loaded = SameDiff.load(savedModel, true);
|
||||||
|
|
||||||
|
assertSameDiffEquals(sd, loaded);
|
||||||
|
} else {
|
||||||
|
throw new IllegalStateException("Unknown configuration/model type: " + config.getClass());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
m = tc.getPretrainedModel();
|
m = tc.getPretrainedModel();
|
||||||
isMLN = (m instanceof MultiLayerNetwork);
|
if (m instanceof MultiLayerNetwork) {
|
||||||
if (isMLN) {
|
|
||||||
mln = (MultiLayerNetwork) m;
|
mln = (MultiLayerNetwork) m;
|
||||||
} else {
|
} else if(m instanceof ComputationGraph) {
|
||||||
cg = (ComputationGraph) m;
|
cg = (ComputationGraph) m;
|
||||||
|
} else if(m instanceof SameDiff){
|
||||||
|
sd = (SameDiff)m;
|
||||||
|
} else {
|
||||||
|
throw new IllegalStateException("Unknown model type: " + m.getClass());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Collect information for test coverage
|
//Collect information for test coverage
|
||||||
|
if(modelType != ModelType.SAMEDIFF) {
|
||||||
collectCoverageInformation(m);
|
collectCoverageInformation(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//Check network output (predictions)
|
//Check network output (predictions)
|
||||||
@@ -210,15 +234,16 @@ public class IntegrationTestRunner {
             log.info("Checking predictions: saved output vs. initialized model");
 
 
-            List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
-            Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
+            List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
+            List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
+            Preconditions.checkState(modelType == ModelType.SAMEDIFF || inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
 
 
             File predictionsTestDir = new File(testBaseDir, "predictions");
             predictionsTestDir.mkdirs();
 
             int count = 0;
-            if (isMLN) {
+            if (modelType == ModelType.MLN) {
                 for (Pair<INDArray[], INDArray[]> p : inputs) {
                     INDArray f = p.getFirst()[0];
                     INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
@@ -231,15 +256,15 @@ public class IntegrationTestRunner {
                         outSaved = Nd4j.read(dis);
                     }
 
-                    INDArray gradExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
-                    int countExceeds = gradExceedsRE.sumNumber().intValue();
+                    INDArray predictionExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
+                    int countExceeds = predictionExceedsRE.sumNumber().intValue();
                     assertEquals("Predictions do not match saved predictions - output", 0, countExceeds);
                 }
-            } else {
+            } else if(modelType == ModelType.CG){
                 for (Pair<INDArray[], INDArray[]> p : inputs) {
                     INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
 
-                    //Save the array(s)...
+                    //Load the previously saved arrays
                     INDArray[] outSaved = new INDArray[out.length];
                     for (int i = 0; i < out.length; i++) {
                         File outFile = new File(predictionsTestDir, "output_" + (count++) + "_" + i + ".bin");
@@ -249,49 +274,86 @@ public class IntegrationTestRunner {
                     }
 
                     for( int i=0; i<outSaved.length; i++ ){
-                        INDArray gradExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
-                        int countExceeds = gradExceedsRE.sumNumber().intValue();
+                        INDArray predictionExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
+                        int countExceeds = predictionExceedsRE.sumNumber().intValue();
                         assertEquals("Predictions do not match saved predictions - output " + i, 0, countExceeds);
                     }
                 }
+            } else {
+                List<String> outNames = tc.getPredictionsNamesSameDiff();
+                for( Map<String,INDArray> ph : inputsSd ){
+                    Map<String,INDArray> out = sd.output(ph, outNames);
+
+                    //Load the previously saved placeholder arrays
+                    Map<String,INDArray> outSaved = new HashMap<>();
+                    for(String s : outNames){
+                        File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
+                        try (DataInputStream dis = new DataInputStream(new FileInputStream(f))) {
+                            outSaved.put(s, Nd4j.read(dis));
+                        }
+                    }
+
+                    for(String s : outNames){
+                        INDArray predictionExceedsRE = exceedsRelError(outSaved.get(s), out.get(s), tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
+                        int countExceeds = predictionExceedsRE.sumNumber().intValue();
+                        assertEquals("Predictions do not match saved predictions - output \"" + s + "\"", 0, countExceeds);
+                    }
+                }
             }
 
+            if(modelType != ModelType.SAMEDIFF) {
                 checkLayerClearance(m);
             }
+        }
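Note: the new SameDiff predictions branch feeds a placeholder map and asks for outputs by name, where the DL4J branches pass positional arrays. A hedged sketch of that inference call; the two-variable graph is illustrative, not from this PR, while sd.output(Map, List) is the same overload the diff uses:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.Collections;
    import java.util.List;
    import java.util.Map;

    public class SameDiffOutputSketch {
        public static void main(String[] args) {
            // Illustrative graph: one placeholder, one weight matrix, one named output
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
            SDVariable w = sd.var("w", Nd4j.rand(4, 3));
            in.mmul("out", w);   // named-op overload gives the output variable the name "out"

            Map<String, INDArray> placeholders = Collections.singletonMap("in", Nd4j.rand(2, 4));
            List<String> outNames = Collections.singletonList("out");

            // Batch inference by placeholder map + output names, as in the predictions check
            Map<String, INDArray> result = sd.output(placeholders, outNames);
            System.out.println(result.get("out"));
        }
    }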
 
 
         //Test gradients
         if (tc.isTestGradients()) {
             log.info("Checking gradients: saved output vs. initialized model");
 
+            INDArray gradientFlat = null;
+            org.deeplearning4j.nn.api.Layer[] layers = null;
+            Map<String,INDArray> grad;
+            if (modelType == ModelType.MLN) {
                 MultiDataSet data = tc.getGradientsTestData();
-            INDArray gradientFlat;
-            org.deeplearning4j.nn.api.Layer[] layers;
-            if (isMLN) {
                 mln.setInput(data.getFeatures(0));
                 mln.setLabels(data.getLabels(0));
                 mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
                 mln.computeGradientAndScore();
                 gradientFlat = mln.getFlattenedGradients();
                 layers = mln.getLayers();
-            } else {
+                grad = mln.gradient().gradientForVariable();
+            } else if(modelType == ModelType.CG) {
+                MultiDataSet data = tc.getGradientsTestData();
                 cg.setInputs(data.getFeatures());
                 cg.setLabels(data.getLabels());
                 cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
                 cg.computeGradientAndScore();
                 gradientFlat = cg.getFlattenedGradients();
                 layers = cg.getLayers();
+                grad = cg.gradient().gradientForVariable();
+            } else {
+                Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
+                List<String> allVars = new ArrayList<>();
+                for(SDVariable v : sd.variables()){
+                    if(v.getVariableType() == VariableType.VARIABLE){
+                        allVars.add(v.name());
+                    }
+                }
+                grad = sd.calculateGradients(ph, allVars);
             }
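Note: for SameDiff, gradients come from a single calculateGradients call that returns a name-to-gradient map, rather than computeGradientAndScore() plus a Gradient object. A sketch under the beta6-era API; the graph, the hand-built MSE loss, and the named-op overloads are assumptions, while calculateGradients and the VariableType filter match the diff:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.VariableType;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class SameDiffGradientsSketch {
        public static void main(String[] args) {
            // Illustrative linear-regression graph with a hand-built MSE loss
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 1);
            SDVariable w = sd.var("w", Nd4j.rand(4, 1));
            SDVariable out = in.mmul("out", w);
            SDVariable diff = out.sub(label);
            diff.mul(diff).mean("loss");     // named mean() overload - assumed available
            sd.setLossVariables("loss");     // tell SameDiff what to differentiate

            Map<String, INDArray> ph = new HashMap<>();
            ph.put("in", Nd4j.rand(8, 4));
            ph.put("label", Nd4j.rand(8, 1));

            // Collect trainable variable names, exactly as the gradients check does
            List<String> allVars = new ArrayList<>();
            for (SDVariable v : sd.variables()) {
                if (v.getVariableType() == VariableType.VARIABLE) {
                    allVars.add(v.name());
                }
            }

            // One call yields a name-to-gradient map; no Gradient object involved
            Map<String, INDArray> grads = sd.calculateGradients(ph, allVars);
            System.out.println(grads.keySet());   // expect: [w]
        }
    }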
 
+            if(modelType != ModelType.SAMEDIFF) {
                 File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
                 INDArray gradientFlatSaved = read(gFlatFile);
 
                 INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
                 int count = gradExceedsRE.sumNumber().intValue();
-            if(count > 0){
+                if (count > 0) {
                     logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
                 }
                 assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
+            }
 
             //Load the gradient table:
             File gradientDir = new File(testBaseDir, "gradients");
@@ -302,12 +364,12 @@ public class IntegrationTestRunner {
                 String key = f.getName();
                 key = key.substring(0, key.length() - 4); //remove ".bin"
                 INDArray loaded = read(f);
-                INDArray now = m.gradient().gradientForVariable().get(key);
+                INDArray now = grad.get(key);
 
 
-                gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
-                count = gradExceedsRE.sumNumber().intValue();
-                assertEquals("Saved flattened gradients: not equal (using relative error) for parameter: " + key, 0, count);
+                INDArray gradExceedsRE = exceedsRelError(loaded, now, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
+                int count = gradExceedsRE.sumNumber().intValue();
+                assertEquals("Gradients: not equal (using relative error) for parameter: " + key, 0, count);
             }
         }
 
@@ -318,7 +380,7 @@ public class IntegrationTestRunner {
 
             INDArray paramsPostTraining;
             org.deeplearning4j.nn.api.Layer[] layers;
-            if(isMLN){
+            if(modelType == ModelType.MLN){
                 int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
                 Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
@@ -328,7 +390,7 @@ public class IntegrationTestRunner {
                 }
                 paramsPostTraining = mln.params();
                 layers = mln.getLayers();
-            } else {
+            } else if(modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                 Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
 
@@ -337,6 +399,8 @@ public class IntegrationTestRunner {
                 }
                 paramsPostTraining = cg.params();
                 layers = cg.getLayers();
+            } else {
+                throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
             }
 
             File f = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_UNSUPERVISED_FILENAME);
@@ -360,53 +424,78 @@ public class IntegrationTestRunner {
             MultiDataSetIterator trainData = tc.getTrainingData();
             boolean isTbptt;
             int tbpttLength;
-            if(isMLN){
+            if(modelType == ModelType.MLN){
                 isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT;
                 tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength();
-            } else {
+            } else if(modelType == ModelType.CG) {
                 isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
                 tbpttLength = cg.getConfiguration().getTbpttFwdLength();
+            } else {
+                isTbptt = false;
+                tbpttLength = 0;
             }
 
             CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
             CollectScoresListener l = new CollectScoresListener(1);
+            if(modelType != ModelType.SAMEDIFF) {
                 m.setListeners(l);
+            }
 
             int iterBefore;
             int epochBefore;
             int iterAfter;
             int epochAfter;
 
-            Map<String,INDArray> frozenParamsBefore = getFrozenLayerParamCopies(m);
-            org.deeplearning4j.nn.api.Layer[] layers;
-            if (isMLN) {
+            Map<String,INDArray> frozenParamsBefore = modelType != ModelType.SAMEDIFF ? getFrozenLayerParamCopies(m) : getConstantCopies(sd);
+            org.deeplearning4j.nn.api.Layer[] layers = null;
+            History h = null;
+            if (modelType == ModelType.MLN) {
                 iterBefore = mln.getIterationCount();
                 epochBefore = mln.getEpochCount();
                 mln.fit(countingIter);
                 iterAfter = mln.getIterationCount();
                 epochAfter = mln.getEpochCount();
                 layers = mln.getLayers();
-            } else {
+            } else if(modelType == ModelType.CG){
                 iterBefore = cg.getConfiguration().getIterationCount();
                 epochBefore = cg.getConfiguration().getEpochCount();
                 cg.fit(countingIter);
                 iterAfter = cg.getConfiguration().getIterationCount();
                 epochAfter = cg.getConfiguration().getEpochCount();
                 layers = cg.getLayers();
+            } else {
+                iterBefore = sd.getTrainingConfig().getIterationCount();
+                epochBefore = sd.getTrainingConfig().getEpochCount();
+                h = sd.fit(countingIter, 1);
+                iterAfter = sd.getTrainingConfig().getIterationCount();
+                epochAfter = sd.getTrainingConfig().getEpochCount();
             }
 
             //Check that frozen params (if any) haven't changed during training:
+            if(modelType == ModelType.SAMEDIFF) {
+                checkConstants(frozenParamsBefore, sd);
+            } else {
                 checkFrozenParams(frozenParamsBefore, m);
+            }
 
             //Validate the iteration and epoch counts - both for the net, and for the layers
             int newIters = countingIter.getCurrIter();
             assertEquals(iterBefore + newIters, iterAfter);
             assertEquals(epochBefore + 1, epochAfter);
-            validateLayerIterCounts(m, epochBefore + 1, iterBefore+newIters); //TODO CURRENTLY FAILING
-            double[] scores = l.getListScore().toDoubleArray();
+            if(modelType != ModelType.SAMEDIFF) {
+                validateLayerIterCounts(m, epochBefore + 1, iterBefore + newIters);
+            }
 
+            double[] scores;
+            if(modelType == ModelType.SAMEDIFF){
+                scores = h.lossCurve().getLossValues().toDoubleVector();
+            } else {
+                scores = l.getListScore().toDoubleArray();
+            }
 
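Note: the SameDiff training path returns a History whose loss curve replaces the CollectScoresListener scores used for DL4J models. A hedged end-to-end sketch; the TrainingConfig wiring, import paths, and the tiny graph are assumptions from the era's API, while fit(iterator, 1) and the lossCurve() accessor chain match the diff:

    import org.nd4j.autodiff.listeners.records.History;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.TrainingConfig;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.dataset.MultiDataSet;
    import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.learning.config.Adam;

    public class SameDiffFitSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 1);
            SDVariable w = sd.var("w", Nd4j.rand(4, 1));
            SDVariable diff = in.mmul("out", w).sub(label);
            diff.mul(diff).mean("loss");
            sd.setLossVariables("loss");

            // The TrainingConfig maps MultiDataSet features/labels onto placeholders by name
            sd.setTrainingConfig(new TrainingConfig.Builder()
                    .updater(new Adam(1e-3))
                    .dataSetFeatureMapping("in")
                    .dataSetLabelMapping("label")
                    .build());

            MultiDataSet mds = new MultiDataSet(Nd4j.rand(8, 4), Nd4j.rand(8, 1));
            History h = sd.fit(new SingletonMultiDataSetIterator(mds), 1);   // one epoch

            // Same accessor chain the runner uses for the training-curve comparison
            double[] scores = h.lossCurve().getLossValues().toDoubleVector();
            System.out.println(scores.length + " loss value(s)");
        }
    }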
             File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
-            String[] s = FileUtils.readFileToString(f).split(",");
+            String[] s = FileUtils.readFileToString(f, StandardCharsets.UTF_8).split(",");
 
             if(tc.isTestTrainingCurves()) {
                 assertEquals("Different number of scores", s.length, scores.length);
@@ -426,18 +515,37 @@ public class IntegrationTestRunner {
             }
 
             if (tc.isTestParamsPostTraining()) {
+                if(modelType != ModelType.SAMEDIFF) {
                     File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
                     INDArray paramsExp = read(p);
                     INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
                     int count = z.sumNumber().intValue();
-                if(count > 0){
+                    if (count > 0) {
                         logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
                     }
                     assertEquals("Number of params exceeded max relative error", 0, count);
+                } else {
+                    File dir = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
+                    for(SDVariable v : sd.variables()){
+                        if(v.getVariableType() != VariableType.VARIABLE)
+                            continue;
+                        INDArray paramNow = v.getArr();
+                        File paramFile = new File(dir, v.name() + ".bin");
+                        INDArray exp = read(paramFile);
+                        INDArray z = exceedsRelError(paramNow, exp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
+                        int count = z.sumNumber().intValue();
+                        if (count > 0) {
+                            logFailedParams(20, "Parameter: " + v.name(), layers, z, exp, paramNow);
+                        }
+                        assertEquals("Number of params exceeded max relative error for parameter: \"" + v.name() + "\"", 0, count);
+                    }
+                }
             }
 
+            if(modelType != ModelType.SAMEDIFF) {
                 checkLayerClearance(m);
             }
+        }
 
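Note: the per-parameter comparison above filters on VariableType.VARIABLE because that is the only variable type holding trainable state. A small sketch of what each type means; the wording in the comments is an interpretation, not from the PR:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;

    public class VariableTypeSketch {
        public static void print(SameDiff sd) {
            for (SDVariable v : sd.variables()) {
                switch (v.getVariableType()) {
                    case VARIABLE:    System.out.println(v.name() + ": trainable parameter - compared against saved .bin"); break;
                    case CONSTANT:    System.out.println(v.name() + ": fixed array - must survive training unchanged"); break;
                    case PLACEHOLDER: System.out.println(v.name() + ": fed per batch - no persistent array"); break;
                    case ARRAY:       System.out.println(v.name() + ": activation - recomputed every forward pass"); break;
                }
            }
        }
    }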
             //Check evaluation:
             if (tc.isTestEvaluation()) {
@@ -445,17 +553,19 @@ public class IntegrationTestRunner {
                 IEvaluation[] evals = tc.getNewEvaluations();
                 MultiDataSetIterator iter = tc.getEvaluationTestData();
 
-                if (isMLN) {
+                if (modelType == ModelType.MLN) {
                     DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
                     mln.doEvaluation(dsi, evals);
-                } else {
+                } else if(modelType == ModelType.CG){
                     cg.doEvaluation(iter, evals);
+                } else {
+                    evals = tc.doEvaluationSameDiff(sd, iter, evals);
                 }
 
                 File evalDir = new File(testBaseDir, "evaluation");
                 for (int i = 0; i < evals.length; i++) {
                     File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
-                    String json = FileUtils.readFileToString(f);
+                    String json = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
                     IEvaluation e;
                     if (evals[i].getClass() == Evaluation.class) {
                         e = Evaluation.fromJson(json);
@@ -479,9 +589,11 @@ public class IntegrationTestRunner {
                     //Evaluation coverage information:
                     evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1);
 
+                    if(modelType != ModelType.SAMEDIFF) {
                         checkLayerClearance(m);
                     }
                 }
+            }
 
             //Check model serialization
             {
@@ -490,15 +602,20 @@ public class IntegrationTestRunner {
                 File f = testDir.newFile();
                 f.delete();
 
+                if (modelType == ModelType.MLN) {
                     ModelSerializer.writeModel(m, f, true);
-                if (isMLN) {
                     MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
                     assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
                     assertEquals(mln.params(), restored.params());
-                } else {
+                } else if(modelType == ModelType.CG){
+                    ModelSerializer.writeModel(m, f, true);
                     ComputationGraph restored = ComputationGraph.load(f, true);
                     assertEquals(cg.getConfiguration(), restored.getConfiguration());
                     assertEquals(cg.params(), restored.params());
+                } else {
+                    sd.save(f, true);
+                    SameDiff restored = SameDiff.load(f, true);
+                    assertSameDiffEquals(sd, restored);
                 }
 
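Note: serialization for SameDiff goes through its own save/load (FlatBuffers format) rather than ModelSerializer. A minimal sketch of the round trip the new branch performs; the signatures match the diff, the file name is illustrative:

    import org.nd4j.autodiff.samediff.SameDiff;

    import java.io.File;

    public class SameDiffRoundTripSketch {
        public static SameDiff roundTrip(SameDiff sd, File dir) {
            File f = new File(dir, "model.fb");
            sd.save(f, true);                    // true: also persist updater state (e.g. Adam m/v)
            return SameDiff.load(f, true);       // then reload and compare via assertSameDiffEquals
        }
    }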
                 System.gc();
@@ -506,7 +623,7 @@ public class IntegrationTestRunner {
 
 
             //Check parallel inference
-            if (tc.isTestParallelInference()) {
+            if (modelType != ModelType.SAMEDIFF && tc.isTestParallelInference()) {
 
                 List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
 
@@ -515,7 +632,7 @@ public class IntegrationTestRunner {
                 List<INDArray[]> exp = new ArrayList<>();
                 for(Pair<INDArray[], INDArray[]> p : inputs){
                     INDArray[] out;
-                    if(isMLN){
+                    if(modelType == ModelType.MLN){
                         INDArray fm = p.getSecond() == null ? null : p.getSecond()[0];
                         out = new INDArray[]{mln.output(p.getFirst()[0], false, fm, null)};
                     } else {
@@ -547,38 +664,55 @@ public class IntegrationTestRunner {
 
                 MultiDataSet toOverfit = tc.getOverfittingData();
                 for (int i = 0; i < tc.getOverfitNumIterations(); i++) {
-                    if (isMLN) {
+                    if (modelType == ModelType.MLN) {
                         mln.fit(toOverfit);
-                    } else {
+                    } else if(modelType == ModelType.CG){
                         cg.fit(toOverfit);
+                    } else {
+                        sd.fit(toOverfit);
                     }
                 }
 
                 //Check:
-                INDArray[] output;
-                if (isMLN) {
+                INDArray[] output = null;
+                Map<String,INDArray> outSd = null;
+                if (modelType == ModelType.MLN) {
                     mln.setLayerMaskArrays(toOverfit.getFeaturesMaskArray(0), null);
                     output = new INDArray[]{mln.output(toOverfit.getFeatures(0))};
-                } else {
+                } else if(modelType == ModelType.CG ){
                     cg.setLayerMaskArrays(toOverfit.getFeaturesMaskArrays(), null);
                     output = cg.output(toOverfit.getFeatures());
+                } else {
+                    List<String> l = sd.getTrainingConfig().getDataSetFeatureMapping();
+                    Map<String,INDArray> phMap = new HashMap<>();
+                    int i=0;
+                    for(String s : l){
+                        phMap.put(s, toOverfit.getFeatures(i++));
+                    }
+                    outSd = sd.output(phMap, tc.getPredictionsNamesSameDiff());
                 }
 
-                for (int i = 0; i < output.length; i++) {
-                    INDArray z = exceedsRelError(output[i], toOverfit.getLabels(i), tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
+                int n = modelType == ModelType.SAMEDIFF ? outSd.size() : output.length;
+                for (int i = 0; i < n; i++) {
+                    INDArray out = modelType == ModelType.SAMEDIFF ? outSd.get(tc.getPredictionsNamesSameDiff().get(i)) : output[i];
+                    INDArray label = toOverfit.getLabels(i);
+
+                    INDArray z = exceedsRelError(out, label, tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
                     int count = z.sumNumber().intValue();
                     if (count > 0) {
-                        System.out.println(output[i]);
-                        System.out.println(toOverfit.getLabels(i));
-                        INDArray re = relativeError(output[i], toOverfit.getLabels(i), tc.getMinAbsErrorOverfit());
+                        System.out.println(out);
+                        System.out.println(label);
+                        INDArray re = relativeError(out, label, tc.getMinAbsErrorOverfit());
                         System.out.println("Relative error:");
                         System.out.println(re);
                     }
                     assertEquals("Number of outputs exceeded max relative error", 0, count);
                 }
 
+                if(modelType != ModelType.SAMEDIFF) {
                     checkLayerClearance(m);
                 }
+            }
 
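Note: unlike mln.output(...) or cg.output(...), SameDiff has no positional inputs; the overfit check therefore rebuilds a placeholder map from the TrainingConfig's feature mapping. A sketch of that logic extracted into a helper; the class and method names are hypothetical, the calls match the diff:

    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.api.MultiDataSet;

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class PlaceholderMappingSketch {
        // The TrainingConfig records which named placeholder each features array feeds;
        // reusing it keeps training and inference wiring consistent.
        public static Map<String, INDArray> featuresToPlaceholders(SameDiff sd, MultiDataSet mds) {
            List<String> names = sd.getTrainingConfig().getDataSetFeatureMapping();
            Map<String, INDArray> phMap = new HashMap<>();
            for (int i = 0; i < names.size(); i++) {
                phMap.put(names.get(i), mds.getFeatures(i));
            }
            return phMap;
        }
    }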
         long end = System.currentTimeMillis();
 
@@ -709,6 +843,16 @@ public class IntegrationTestRunner {
         return out;
     }
 
+    private static Map<String,INDArray> getConstantCopies(SameDiff sd){
+        Map<String,INDArray> out = new HashMap<>();
+        for(SDVariable v : sd.variables()){
+            if(v.isConstant()){
+                out.put(v.name(), v.getArr());
+            }
+        }
+        return out;
+    }
 
     public static void checkFrozenParams(Map<String,INDArray> copiesBeforeTraining, Model m){
         for(Map.Entry<String,INDArray> e : copiesBeforeTraining.entrySet()){
             INDArray actual = m.getParam(e.getKey());
@@ -716,6 +860,13 @@ public class IntegrationTestRunner {
         }
     }
 
+    public static void checkConstants(Map<String,INDArray> copiesBefore, SameDiff sd){
+        for(Map.Entry<String,INDArray> e : copiesBefore.entrySet()){
+            INDArray actual = sd.getArrForVarName(e.getKey());
+            assertEquals(e.getKey(), e.getValue(), actual);
+        }
+    }
 
     public static void printCoverageInformation(){
 
         log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
@@ -918,7 +1069,7 @@ public class IntegrationTestRunner {
     }
 
 
-    public static void logFailedParams(int maxNum, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
+    public static void logFailedParams(int maxNumToPrintOnFailure, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
         long length = exceedsRelError.length();
         int logCount = 0;
         for(int i=0; i<length; i++ ){
@@ -947,10 +1098,33 @@ public class IntegrationTestRunner {
             }
 
             log.info("{} {} ({}) failed: expected {} vs actual {} (RelativeError: {}, AbsError: {})", i, prefix, pName, dExp, dAct, re, ae);
-            if(++logCount >= maxNum){
+            if(++logCount >= maxNumToPrintOnFailure){
                 break;
             }
         }
     }
 
+    public static void assertSameDiffEquals(SameDiff sd1, SameDiff sd2){
+        assertEquals(sd1.variableMap().keySet(), sd2.variableMap().keySet());
+        assertEquals(sd1.getOps().keySet(), sd2.getOps().keySet());
+        assertEquals(sd1.inputs(), sd2.inputs());
+
+        //Check constant and variable arrays:
+        for(SDVariable v : sd1.variables()){
+            String n = v.name();
+            assertEquals(n, v.getVariableType(), sd2.getVariable(n).getVariableType());
+            if(v.isConstant() || v.getVariableType() == VariableType.VARIABLE){
+                INDArray a1 = v.getArr();
+                INDArray a2 = sd2.getVariable(n).getArr();
+                assertEquals(n, a1, a2);
+            }
+        }
+
+        //Check ops:
+        for(SameDiffOp o : sd1.getOps().values()){
+            SameDiffOp o2 = sd2.getOps().get(o.getName());
+            assertEquals(o.getOp().getClass(), o2.getOp().getClass());
+        }
+    }
 }
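Note: getConstantCopies/checkConstants give SameDiff an analogue of the DL4J frozen-layer check: constants must come out of training bit-identical. A hedged sketch of the snapshot half; dup() is a defensive copy added here for illustration, whereas the runner stores the arrays directly:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;

    import java.util.HashMap;
    import java.util.Map;

    public class ConstantSnapshotSketch {
        // Snapshot constants before fit(), then compare after training to verify
        // nothing mutated them (the comparison side uses sd.getArrForVarName).
        public static Map<String, INDArray> snapshotConstants(SameDiff sd) {
            Map<String, INDArray> out = new HashMap<>();
            for (SDVariable v : sd.variables()) {
                if (v.isConstant()) {
                    out.put(v.name(), v.getArr().dup());
                }
            }
            return out;
        }
    }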
===== IntegrationTestsDL4J.java (class renamed from IntegrationTests) =====
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -17,15 +18,19 @@
 package org.deeplearning4j.integration;
 
 import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.integration.testcases.*;
+import org.deeplearning4j.integration.testcases.dl4j.*;
 import org.junit.AfterClass;
-import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
-@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
-public class IntegrationTests extends BaseDL4JTest {
+//@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
+public class IntegrationTestsDL4J extends BaseDL4JTest {
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 300_000L;
+    }
 
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();
@@ -36,79 +41,72 @@ public class IntegrationTests extends BaseDL4JTest {
     }
 
     // ***** MLPTestCases *****
-    @Test(timeout = 20000L)
+    @Test
     public void testMLPMnist() throws Exception {
         IntegrationTestRunner.runTest(MLPTestCases.getMLPMnist(), testDir);
     }
 
-    @Test(timeout = 30000L)
+    @Test
     public void testMlpMoon() throws Exception {
         IntegrationTestRunner.runTest(MLPTestCases.getMLPMoon(), testDir);
     }
 
     // ***** RNNTestCases *****
-    @Test(timeout = 30000L)
+    @Test
     public void testRnnSeqClassification1() throws Exception {
         IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), testDir);
     }
 
-    @Test(timeout = 60000L)
+    @Test
     public void testRnnSeqClassification2() throws Exception {
         IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), testDir);
     }
 
-    @Test(timeout = 120000L)
+    @Test
     public void testRnnCharacter() throws Exception {
         IntegrationTestRunner.runTest(RNNTestCases.getRnnCharacterTestCase(), testDir);
     }
 
 
     // ***** CNN1DTestCases *****
-    @Test(timeout = 180000L)
+    @Test
     public void testCnn1dCharacter() throws Exception {
         IntegrationTestRunner.runTest(CNN1DTestCases.getCnn1dTestCaseCharRNN(), testDir);
     }
 
 
     // ***** CNN2DTestCases *****
-    @Test(timeout = 120000L)
+    @Test
     public void testLenetMnist() throws Exception {
         IntegrationTestRunner.runTest(CNN2DTestCases.getLenetMnist(), testDir);
     }
 
-    @Ignore //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6017
-    @Test(timeout = 180000L)
+    @Test
     public void testYoloHouseNumbers() throws Exception {
         IntegrationTestRunner.runTest(CNN2DTestCases.getYoloHouseNumbers(), testDir);
     }
 
-    @Test(timeout = 120000L)
+    @Test
     public void testCnn2DLenetTransferDropoutRepeatability() throws Exception {
         IntegrationTestRunner.runTest(CNN2DTestCases.testLenetTransferDropoutRepeatability(), testDir);
     }
 
 
     // ***** CNN3DTestCases *****
-    @Test(timeout = 180000L)
+    @Test
     public void testCnn3dSynthetic() throws Exception {
         IntegrationTestRunner.runTest(CNN3DTestCases.getCnn3dTestCaseSynthetic(), testDir);
     }
 
 
     // ***** UnsupervisedTestCases *****
-    @Test(timeout = 120000L)
+    @Test
     public void testVAEMnistAnomaly() throws Exception {
         IntegrationTestRunner.runTest(UnsupervisedTestCases.getVAEMnistAnomaly(), testDir);
     }
 
     // ***** TransferLearningTestCases *****
-    @Test(timeout = 360000L)
+    @Test
     public void testVgg16Transfer() throws Exception {
         IntegrationTestRunner.runTest(CNN2DTestCases.getVGG16TransferTinyImagenet(), testDir);
     }
 
 
-    // ***** KerasImportTestCases *****
-    //TODO
 
 }
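Note: the per-method JUnit timeouts removed above are replaced by a single class-level limit inherited from BaseDL4JTest. A sketch of the hook; the class name here is hypothetical, while getTimeoutMilliseconds() is the override this PR actually uses:

    import org.deeplearning4j.BaseDL4JTest;

    public class TimeoutSketch extends BaseDL4JTest {
        // One class-wide limit replaces the per-method @Test(timeout = ...) values
        @Override
        public long getTimeoutMilliseconds() {
            return 300_000L;   // 5 minutes for every test in the class
        }
    }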
===== IntegrationTestsSameDiff.java (new file) =====
@@ -0,0 +1,40 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.integration;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class IntegrationTestsSameDiff extends BaseDL4JTest {
+
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 300_000L;
+    }
+
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
+
+    @Test
+    public void testMLPMnist() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
+    }
+
+}
===== ModelType.java (new file) =====
@@ -0,0 +1,20 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.integration;
+
+public enum ModelType {
+    MLN, CG, SAMEDIFF
+}
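Note: the enum is the successor of the old isMLN flag, which could only encode two model families. A small sketch of how runner-style branching reads with the enum; the helper class is hypothetical:

    import org.deeplearning4j.integration.ModelType;

    public class ModelTypeSwitchSketch {
        // The SAMEDIFF case becomes explicit rather than hiding in an "else"
        public static String runnerPath(ModelType t) {
            switch (t) {
                case MLN:      return "MultiLayerNetwork path";
                case CG:       return "ComputationGraph path";
                case SAMEDIFF: return "SameDiff path";
                default:       throw new IllegalStateException("Unknown: " + t);
            }
        }
    }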
===== TestCase.java =====
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -17,8 +18,9 @@
 package org.deeplearning4j.integration;
 
 import lombok.Data;
-import org.deeplearning4j.eval.IEvaluation;
 import org.deeplearning4j.nn.api.Model;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@@ -26,6 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
 
 import java.io.File;
 import java.util.List;
+import java.util.Map;
 
 /**
  * A single test case for integration tests
@@ -37,16 +40,17 @@ public abstract class TestCase {
         PRETRAINED, RANDOM_INIT
     }
 
-    protected String testName;
-    protected TestType testType;
-    protected boolean testPredictions = true;
-    protected boolean testGradients = true;
-    protected boolean testUnsupervisedTraining = false;
-    protected boolean testTrainingCurves = true;
-    protected boolean testParamsPostTraining = true;
-    protected boolean testEvaluation = true;
-    protected boolean testParallelInference = true;
-    protected boolean testOverfitting = true;
+    //See: readme.md for more details
+    protected String testName; //Name of the test, for display purposes
+    protected TestType testType; //Type of model - from a pretrained model, or a randomly initialized model
+    protected boolean testPredictions = true; //If true: check the predictions/output. Requires getPredictionsTestData() to be implemented
+    protected boolean testGradients = true; //If true: check the gradients. Requires getGradientsTestData() to be implemented
+    protected boolean testUnsupervisedTraining = false; //If true: perform unsupervised training. Only applies to layers like autoencoders, VAEs, etc. Requires getUnsupervisedTrainData() to be implemented
+    protected boolean testTrainingCurves = true; //If true: perform training, and compare loss vs. iteration. Requires getTrainingData() method
+    protected boolean testParamsPostTraining = true; //If true: perform training, and compare parameters after training. Requires getTrainingData() method
+    protected boolean testEvaluation = true; //If true: perform evaluation. Requires getNewEvaluations() and getEvaluationTestData() methods implemented
+    protected boolean testParallelInference = true; //If true: run the model through ParallelInference. Requires getPredictionsTestData() method. Only applies to DL4J models, NOT SameDiff models
+    protected boolean testOverfitting = true; //If true: perform overfitting, and ensure the predictions match the training data. Requires both getOverfittingData() and getOverfitNumIterations()
 
     protected int[] unsupervisedTrainLayersMLN = null;
     protected String[] unsupervisedTrainLayersCG = null;
@@ -65,6 +69,8 @@ public abstract class TestCase {
     protected double maxRelativeErrorOverfit = 1e-2;
     protected double minAbsErrorOverfit = 1e-2;
 
+    public abstract ModelType modelType();
+
     /**
      * Initialize the test case... many tests don't need this; others may use it to download or create data
     * @param testWorkingDir Working directory to use for test
@@ -88,19 +94,37 @@ public abstract class TestCase {
     }
 
     /**
-     * Required if testPredictions == true
+     * Required if testPredictions == true && DL4J model (MultiLayerNetwork or ComputationGraph)
     */
     public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception {
         throw new RuntimeException("Implementations must override this method if used");
     }
 
     /**
-     * Required if testGradients == true
+     * Required if testPredictions == true && SameDiff model
+     */
+    public List<Map<String,INDArray>> getPredictionsTestDataSameDiff() throws Exception {
+        throw new RuntimeException("Implementations must override this method if used");
+    }
+
+    public List<String> getPredictionsNamesSameDiff() throws Exception {
+        throw new RuntimeException("Implementations must override this method if used");
+    }
+
+    /**
+     * Required if testGradients == true && DL4J model
     */
     public MultiDataSet getGradientsTestData() throws Exception {
         throw new RuntimeException("Implementations must override this method if used");
     }
 
+    /**
+     * Required if testGradients == true && SameDiff model
+     */
+    public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
+        throw new RuntimeException("Implementations must override this method if used");
+    }
+
     /**
      * Required when testUnsupervisedTraining == true
     */
@@ -122,6 +146,10 @@ public abstract class TestCase {
         throw new RuntimeException("Implementations must override this method if used");
     }
 
+    public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
+        throw new RuntimeException("Implementations must override this method if used");
+    }
+
     /**
      * Required if testEvaluation == true
     */
@@ -130,12 +158,19 @@ public abstract class TestCase {
     }
 
     /**
-     * Required if testOverfitting == true
+     * Required if testOverfitting == true && DL4J model
     */
     public MultiDataSet getOverfittingData() throws Exception {
         throw new RuntimeException("Implementations must override this method if used");
     }
 
+    /**
+     * Required if testOverfitting == true && SameDiff model
+     */
+    public Map<String,INDArray> getOverfittingDataSameDiff() throws Exception {
+        throw new RuntimeException("Implementations must override this method if used");
+    }
+
     /**
      * Required if testOverfitting == true
     */
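Note: with the new hooks in TestCase, a SameDiff-backed case only overrides the SameDiff-flavoured methods. A hedged sketch of a minimal subclass; it assumes the remaining hooks keep their default throwing implementations and that getConfiguration() has the signature seen in the other test cases of this PR, and the placeholder/output names and sizes are invented:

    import org.deeplearning4j.integration.ModelType;
    import org.deeplearning4j.integration.TestCase;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.Collections;
    import java.util.List;
    import java.util.Map;

    public class MySameDiffTestCase extends TestCase {
        public MySameDiffTestCase() {
            testName = "mySameDiffMlp";
            testType = TestType.RANDOM_INIT;
            testParallelInference = false;    // DL4J-only feature, per the field comment above
            testUnsupervisedTraining = false; // layerwise pretraining throws for SameDiff
        }

        @Override
        public ModelType modelType() {
            return ModelType.SAMEDIFF;
        }

        @Override
        public Object getConfiguration() {
            SameDiff sd = SameDiff.create();
            // ... build the graph and set a TrainingConfig here ...
            return sd;
        }

        // Placeholder values for the predictions check; names must match the graph
        @Override
        public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() {
            return Collections.singletonList(Collections.singletonMap("in", Nd4j.rand(1, 784)));
        }

        @Override
        public List<String> getPredictionsNamesSameDiff() {
            return Collections.singletonList("out");
        }
    }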
===== TransferLearningTestCases.java (deleted) =====
@@ -1,36 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.integration.testcases;
-
-import org.deeplearning4j.integration.TestCase;
-
-public class TransferLearningTestCases {
-
-    public static TestCase testPartFrozenResNet50(){
-
-        throw new UnsupportedOperationException("Not yet implemented");
-    }
-
-
-    public static TestCase testPartFrozenNASNET(){
-
-        throw new UnsupportedOperationException("Not yet implemented");
-    }
-
-
-
-}
===== CNN1DTestCases.java =====
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -14,22 +15,24 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
-package org.deeplearning4j.integration.testcases;
+package org.deeplearning4j.integration.testcases.dl4j;
 
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.IEvaluation;
-import org.deeplearning4j.eval.ROCMultiClass;
+import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
-import org.deeplearning4j.integration.testcases.misc.CharacterIterator;
+import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -64,12 +67,18 @@ public class CNN1DTestCases {
             int miniBatchSize = 16;
             int exampleLength = 128;
 
+            @Override
+            public ModelType modelType() {
+                return ModelType.CG;
+            }
+
             @Override
             public Object getConfiguration() throws Exception {
                 CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
                 int nOut = iter.totalOutcomes();
 
                 return new NeuralNetConfiguration.Builder()
+                        .dataType(DataType.FLOAT)
                         .seed(12345)
                         .weightInit(WeightInit.XAVIER)
                         .updater(new Adam(0.01))
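Note: both this file and the next add .dataType(DataType.FLOAT) to their configuration builders, pinning the network dtype instead of inheriting the global ND4J default. A stand-alone sketch of the pattern; the layer and its sizes are arbitrary:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class DataTypePinSketch {
        public static void main(String[] args) {
            // dataType(FLOAT) fixes the network's dtype regardless of the global
            // default, keeping the integration-test tolerances stable across runs
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .dataType(DataType.FLOAT)
                    .seed(12345)
                    .list()
                    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                            .activation(Activation.IDENTITY).nIn(4).nOut(3).build())
                    .build();
            System.out.println(conf.toJson());
        }
    }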
===== CNN2DTestCases.java =====
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -14,7 +15,7 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
-package org.deeplearning4j.integration.testcases;
+package org.deeplearning4j.integration.testcases.dl4j;
 
 import org.datavec.api.split.FileSplit;
 import org.datavec.image.loader.NativeImageLoader;
@@ -22,16 +23,13 @@ import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
 import org.datavec.image.recordreader.objdetect.impl.SvhnLabelProvider;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.fetchers.SvhnDataFetcher;
+import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
 import org.deeplearning4j.datasets.fetchers.DataSetType;
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
 import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.EvaluationCalibration;
-import org.deeplearning4j.eval.IEvaluation;
-import org.deeplearning4j.eval.ROCMultiClass;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.*;
@@ -47,7 +45,12 @@ import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.zoo.PretrainedType;
 import org.deeplearning4j.zoo.model.TinyYOLO;
 import org.deeplearning4j.zoo.model.VGG16;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -82,12 +85,18 @@ public class CNN2DTestCases {
                 testOverfitting = false;
             }
 
+            @Override
+            public ModelType modelType() {
+                return ModelType.MLN;
+            }
+
             public Object getConfiguration() throws Exception {
                 int nChannels = 1; // Number of input channels
                 int outputNum = 10; // The number of possible outcomes
                 int seed = 123;
 
                 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .dataType(DataType.FLOAT)
                         .seed(seed)
                         .l2(0.0005)
                         .weightInit(WeightInit.XAVIER)
@@ -187,6 +196,11 @@ public class CNN2DTestCases {
                 testOverfitting = false;
             }
 
+            @Override
+            public ModelType modelType() {
+                return ModelType.CG;
+            }
+
             @Override
             public Model getPretrainedModel() throws Exception {
                 VGG16 vgg16 = VGG16.builder()
@@ -269,6 +283,11 @@ public class CNN2DTestCases {
                 testOverfitting = false;
             }
 
+            @Override
+            public ModelType modelType() {
+                return ModelType.CG;
+            }
+
             @Override
             public Model getPretrainedModel() throws Exception {
                 int nClasses = 10;
@@ -372,6 +391,11 @@ public class CNN2DTestCases {
                 testOverfitting = true;
             }
 
+            @Override
+            public ModelType modelType() {
+                return ModelType.CG;
+            }
+
             @Override
             public Model getPretrainedModel() throws Exception {
 
@@ -381,6 +405,7 @@ public class CNN2DTestCases {
                 lrSchedule.put(3000, 0.001);
 
                 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .dataType(DataType.FLOAT)
                         .seed(12345)
                         .l2(0.0005)
|
.l2(0.0005)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
|
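The hunks above pin every test network to single precision via the new .dataType(DataType.FLOAT) builder call. A minimal sketch of what that call fixes, assuming a hypothetical one-layer network that is not part of this PR:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class DataTypePinningSketch {
    public static void main(String[] args) {
        // dataType(...) fixes the dtype of parameters, activations and updater
        // state for the whole network, independent of the ND4J default dtype.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .dataType(DataType.FLOAT)
                .seed(12345)
                .list()
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(784).nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        System.out.println(net.params().dataType()); // expected: FLOAT
    }
}

Pinning the dtype makes the integration tests deterministic with respect to whatever global dtype another test may have set.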
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -14,35 +15,31 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.integration.testcases;
+package org.deeplearning4j.integration.testcases.dl4j;

-import org.apache.commons.math3.stat.inference.TestUtils;
-import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
-import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
 import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.IEvaluation;
-import org.deeplearning4j.eval.ROCMultiClass;
+import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.Convolution3D;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Nesterovs;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.linalg.primitives.Pair;

-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;

@@ -66,6 +63,11 @@ public class CNN3DTestCases {
             testOverfitting = false;
         }

+        @Override
+        public ModelType modelType() {
+            return ModelType.MLN;
+        }
+
         public Object getConfiguration() throws Exception {
             int nChannels = 3; // Number of input channels
             int outputNum = 10; // The number of possible outcomes
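Each test case in this PR now declares its model family through the new modelType() hook (MLN, CG, or SAMEDIFF, matching the values returned across these files). A sketch of how a harness might branch on it; the class and method here are hypothetical, not taken from the PR:

import org.deeplearning4j.integration.ModelType;

public class ModelTypeDispatchSketch {
    // Hypothetical dispatch helper; only the three ModelType values actually
    // returned by the test cases in this PR are handled.
    static String describe(ModelType type) {
        switch (type) {
            case MLN:      return "MultiLayerNetwork test case";
            case CG:       return "ComputationGraph test case";
            case SAMEDIFF: return "SameDiff graph test case";
            default:       throw new IllegalStateException("Unhandled model type: " + type);
        }
    }
}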
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -14,8 +15,9 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.integration.testcases;
+package org.deeplearning4j.integration.testcases.dl4j;

+import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
 import org.datavec.api.records.reader.RecordReader;
 import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@@ -24,10 +26,6 @@ import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.EvaluationCalibration;
-import org.deeplearning4j.eval.IEvaluation;
-import org.deeplearning4j.eval.ROCMultiClass;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -35,7 +33,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -76,9 +79,15 @@ public class MLPTestCases {
                 minAbsErrorOverfit = 1e-2;
             }

+            @Override
+            public ModelType modelType() {
+                return ModelType.MLN;
+            }
+
             @Override
             public Object getConfiguration() {
                 return new NeuralNetConfiguration.Builder()
+                        .dataType(DataType.FLOAT)
                         .seed(12345)
                         .updater(new Adam(new MapSchedule.Builder(ScheduleType.ITERATION)
                                 .add(0, 5e-2)
@@ -168,6 +177,11 @@ public class MLPTestCases {
                 testOverfitting = false; //Not much point here: very simple training data
             }

+            @Override
+            public ModelType modelType() {
+                return ModelType.MLN;
+            }
+
             @Override
             public Object getConfiguration() {
                 int seed = 123;
@@ -179,6 +193,7 @@ public class MLPTestCases {

                 //log.info("Build model....");
                 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .dataType(DataType.FLOAT)
                         .seed(seed)
                         .updater(new Nesterovs(learningRate, 0.9))
                         .list()
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -14,22 +15,24 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.integration.testcases;
+package org.deeplearning4j.integration.testcases.dl4j;

+import org.deeplearning4j.integration.ModelType;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
 import org.nd4j.shade.guava.io.Files;
 import org.deeplearning4j.integration.TestCase;
-import org.deeplearning4j.integration.testcases.misc.CharacterIterator;
-import org.deeplearning4j.integration.testcases.misc.CompositeMultiDataSetPreProcessor;
+import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
 import org.datavec.api.records.reader.SequenceRecordReader;
 import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
 import org.datavec.api.split.NumberedFileInputSplit;
 import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.EvaluationCalibration;
-import org.deeplearning4j.eval.IEvaluation;
-import org.deeplearning4j.eval.ROCMultiClass;
 import org.deeplearning4j.nn.conf.BackpropType;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -91,6 +94,11 @@ public class RNNTestCases {
         private int exampleLength = 1000;


+        @Override
+        public ModelType modelType() {
+            return ModelType.MLN;
+        }
+
         @Override
         public Object getConfiguration() throws Exception {

@@ -101,6 +109,7 @@ public class RNNTestCases {
             int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters

             return new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.FLOAT)
                     .seed(12345)
                     .l2(0.001)
                     .weightInit(WeightInit.XAVIER)
@@ -175,9 +184,15 @@ public class RNNTestCases {
             return normalizer;
         }

+        @Override
+        public ModelType modelType() {
+            return ModelType.MLN;
+        }
+
         @Override
         public Object getConfiguration() throws Exception {
             return new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.FLOAT)
                     .seed(12345)
                     .updater(new Adam(5e-2))
                     .l1(1e-3).l2(1e-3)
@@ -298,6 +313,7 @@ public class RNNTestCases {
         @Override
         public Object getConfiguration() throws Exception {
             return new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.FLOAT)
                     .seed(12345)
                     .updater(new Adam(5e-2))
                     .l1(1e-3).l2(1e-3)
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -14,18 +15,20 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.integration.testcases;
+package org.deeplearning4j.integration.testcases.dl4j;


 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
+import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@@ -59,9 +62,15 @@ public class UnsupervisedTestCases {
             minAbsErrorPretrainParams = 5e-4;
         }

+        @Override
+        public ModelType modelType() {
+            return ModelType.MLN;
+        }
+
         @Override
         public Object getConfiguration() {
             return new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.FLOAT)
                     .seed(12345)
                     .updater(new Adam(0.05))
                     .weightInit(WeightInit.XAVIER)
@@ -14,7 +14,7 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.integration.testcases.misc;
+package org.deeplearning4j.integration.testcases.dl4j.misc;

 import org.apache.commons.io.FileUtils;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -1,36 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.integration.testcases.misc;
-
-import org.nd4j.linalg.dataset.api.MultiDataSet;
-import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
-
-public class CompositeMultiDataSetPreProcessor implements MultiDataSetPreProcessor {
-
-    private MultiDataSetPreProcessor[] preProcessors;
-
-    public CompositeMultiDataSetPreProcessor(MultiDataSetPreProcessor... preProcessors){
-        this.preProcessors = preProcessors;
-    }
-
-    @Override
-    public void preProcess(MultiDataSet multiDataSet) {
-        for(MultiDataSetPreProcessor p : preProcessors){
-            p.preProcess(multiDataSet);
-        }
-    }
-}
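The local CompositeMultiDataSetPreProcessor deleted above is replaced by the ND4J class of the same name, which RNNTestCases now imports from org.nd4j.linalg.dataset.api.preprocessor. A sketch of the equivalent usage, assuming the ND4J class keeps the same varargs constructor; the two component preprocessors here are hypothetical:

import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;

public class CompositePreProcessorSketch {
    public static void main(String[] args) {
        // MultiDataSetPreProcessor has a single preProcess(MultiDataSet) method,
        // so lambdas work here. Both components are illustrative only.
        MultiDataSetPreProcessor scale = mds -> mds.getFeatures(0).divi(255.0);
        MultiDataSetPreProcessor center = mds -> mds.getFeatures(0).subi(0.5);

        // Applies the components in the order given, like the deleted local class.
        MultiDataSetPreProcessor both = new CompositeMultiDataSetPreProcessor(scale, center);
        // iterator.setPreProcessor(both);
    }
}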
@@ -0,0 +1,155 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.integration.testcases.samediff;
+
+import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
+import org.deeplearning4j.integration.ModelType;
+import org.deeplearning4j.integration.TestCase;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.TrainingConfig;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.DataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+
+import java.util.*;
+
+public class SameDiffMLPTestCases {
+
+
+    public static TestCase getMLPMnist(){
+        return new TestCase() {
+            {
+                testName = "MLPMnistSD";
+                testType = TestType.RANDOM_INIT;
+                testPredictions = true;
+                testTrainingCurves = true;
+                testGradients = true;
+                testParamsPostTraining = true;
+                testEvaluation = true;
+                testOverfitting = true;
+                maxRelativeErrorOverfit = 2e-2;
+                minAbsErrorOverfit = 1e-2;
+            }
+
+            @Override
+            public ModelType modelType() {
+                return ModelType.SAMEDIFF;
+            }
+
+            @Override
+            public Object getConfiguration() throws Exception {
+                Nd4j.getRandom().setSeed(12345);
+
+                //Define the network structure:
+                SameDiff sd = SameDiff.create();
+                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
+                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
+
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
+                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
+                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
+
+                SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
+                SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
+                SDVariable loss = sd.loss.logLoss("loss", label, out);
+
+                //Also set the training configuration:
+                sd.setTrainingConfig(TrainingConfig.builder()
+                        .updater(new Adam(0.01))
+                        .weightDecay(1e-3, true)
+                        .dataSetFeatureMapping("in")            //features[0] -> "in" placeholder
+                        .dataSetLabelMapping("label")           //labels[0]   -> "label" placeholder
+                        .build());
+
+                return sd;
+            }
+
+            @Override
+            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
+                List<Map<String,INDArray>> out = new ArrayList<>();
+
+                DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
+                out.add(Collections.singletonMap("in", iter.next().getFeatures()));
+
+                iter = new MnistDataSetIterator(8, true, 12345);
+                out.add(Collections.singletonMap("in", iter.next().getFeatures()));
+
+                return out;
+            }
+
+            @Override
+            public List<String> getPredictionsNamesSameDiff() throws Exception {
+                return Collections.singletonList("out");
+            }
+
+            @Override
+            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
+                DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
+                Map<String,INDArray> map = new HashMap<>();
+                map.put("in", ds.getFeatures());
+                map.put("label", ds.getLabels());
+                return map;
+            }
+
+            @Override
+            public MultiDataSetIterator getTrainingData() throws Exception {
+                DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
+                iter = new EarlyTerminationDataSetIterator(iter, 32);
+                return new MultiDataSetIteratorAdapter(iter);
+            }
+
+            @Override
+            public IEvaluation[] getNewEvaluations() {
+                return new IEvaluation[]{new Evaluation()};
+            }
+
+            @Override
+            public MultiDataSetIterator getEvaluationTestData() throws Exception {
+                DataSetIterator iter = new MnistDataSetIterator(8, false, 12345);
+                iter = new EarlyTerminationDataSetIterator(iter, 32);
+                return new MultiDataSetIteratorAdapter(iter);
+            }
+
+            @Override
+            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
+                sd.evaluate(iter, "out", 0, evaluations);
+                return evaluations;
+            }
+
+            @Override
+            public MultiDataSet getOverfittingData() throws Exception {
+                return new MnistDataSetIterator(1, true, 12345).next().toMultiDataSet();
+            }
+
+            @Override
+            public int getOverfitNumIterations() {
+                return 100;
+            }
+        };
+    }
+
+}
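The new SameDiff test case above returns a graph with its TrainingConfig already attached, so training and evaluation are just a fit plus the same evaluate call the test case uses in doEvaluationSameDiff(). A sketch, assuming sd.fit(MultiDataSetIterator, int) with the epoch count (1 here) chosen arbitrarily:

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public class SameDiffTrainSketch {
    static Evaluation trainAndEvaluate(SameDiff sd, MultiDataSetIterator train, MultiDataSetIterator test) {
        // TrainingConfig (Adam, weight decay, feature/label mappings) was set in getConfiguration()
        sd.fit(train, 1);

        // Evaluate the "out" softmax against labels[0], mirroring doEvaluationSameDiff()
        Evaluation eval = new Evaluation();
        sd.evaluate(test, "out", 0, eval);
        return eval;
    }
}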
@@ -1,17 +1,23 @@
 cmake_minimum_required(VERSION 3.15)
 project(libnd4j)
 set(CMAKE_VERBOSE_MAKEFILE OFF)
-option(NATIVE "Optimize for build machine (might not work on others)" OFF)
 set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
 #ensure we create lib files
 set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
-option(CHECK_VECTORIZATION "checks for vectorization" OFF)
-option(BUILD_TESTS "Build tests" OFF)
+option(SD_NATIVE "Optimize for build machine (might not work on others)" OFF)
+option(SD_CHECK_VECTORIZATION "checks for vectorization" OFF)
+option(SD_BUILD_TESTS "Build tests" OFF)
+option(SD_STATIC_LIB "Build static library" OFF)
+option(SD_SHARED_LIB "Build shared library" ON)
+option(SD_SANITIZE "Enable Address Sanitizer" ON)

 option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler" OFF)
 set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FORCE)

 set(CMAKE_CXX_STANDARD 11)
-if (CUDA_BLAS)
+if (SD_CUDA)
     enable_language(CUDA)
     set(CMAKE_CUDA_STANDARD 11)
@@ -23,23 +29,23 @@ endif()
 # MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively
 set(MSVC_RT_LIB "MultiThreadedDLL")

-set(X86_BUILD false)
+set(SD_X86_BUILD false)

-if (NOT IOS_BUILD AND NOT ANDROID_BUILD AND NOT ${ARCH} MATCHES "power*" AND NOT ${ARCH} MATCHES "arm*")
-    set(X86_BUILD true)
+if (NOT SD_IOS_BUILD AND NOT SD_ANDROID_BUILD AND NOT ${SD_ARCH} MATCHES "power*" AND NOT ${SD_ARCH} MATCHES "arm*")
+    set(SD_X86_BUILD true)
 endif()

 # -fsanitize=address
 # -fsanitize=leak
-if (ANDROID_BUILD)
+if (SD_ANDROID_BUILD)
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D_RELEASE=true")
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else")
 elseif (APPLE)
     set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
     set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true")
 elseif(WIN32)
-    set(X86_BUILD true)
-    if (CUDA_BLAS)
+    set(SD_X86_BUILD true)
+    if (SD_CUDA)
         set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true")
         set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc")
     else()
@@ -50,14 +56,14 @@ else()
     set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -fmax-errors=2 -D_RELEASE=true")
     set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -fmax-errors=2")

-    if (CPU_BLAS)
+    if (SD_CPU)
         set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
     endif()
 endif()

-if(NATIVE)
+if(SD_NATIVE)
     IF(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
-        set(X86_BUILD false)
+        set(SD_X86_BUILD false)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=native")
     ELSE()
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
@@ -65,14 +71,13 @@ if(NATIVE)
 endif()


-if(NOT CUDA_BLAS)
+if(NOT SD_CUDA)
     # we need this definition to avoid global memory use within mkldnn
     add_definitions(-DDNNL_ENABLE_CONCURRENT_EXEC=true)

     # there's a chance, we have no BLAS provided externally
     if ("${OPENBLAS_PATH}" STREQUAL "")
-        #we don't want static OpenBLAS on Apple
-        set(BLA_STATIC ON)
+        #we don't want OpenBLAS on Apple
         if (NOT APPLE)
             set(BLA_VENDOR "OpenBLAS")
         endif()
@@ -80,23 +85,8 @@ if(NOT CUDA_BLAS)
         # look around for system blas instead
         find_package(BLAS REQUIRED)
         if (BLAS_FOUND)
-            message("Original library: ${BLAS_LIBRARIES}")
-            # workaround for for cmake being unable to find static blas library
-            SET(_TMP_B "")
-            if (APPLE)
-                string(REGEX REPLACE "\\.dylib$" ".lib" _TMP_B "${BLAS_LIBRARIES}")
-            elseif (WIN32)
-                string(REGEX REPLACE "\\.dll" ".lib" _TMP_B "${BLAS_LIBRARIES}")
-            else()
-                string(REGEX REPLACE "\\.so$" ".a" _TMP_B "${BLAS_LIBRARIES}")
-            endif()
-            set(BLAS_LIBRARIES "${_TMP_B}")
-
             message("Found external BLAS implementation: ${BLAS_LIBRARIES} ")
             add_definitions(-D__EXTERNAL_BLAS__=true)
-        elseif(WIN32)
-            message("BLAS not found, using downloaded OpenBLAS instead")
-            add_definitions(-D__EXTERNAL_BLAS__=true)
         endif()
     else()
         # if we have externally provided OPENBLAS_PATH - let's use it
@@ -107,7 +97,7 @@ if(NOT CUDA_BLAS)
     endif()

     # building cpu_features
-    if (X86_BUILD)
+    if (SD_X86_BUILD)
         add_definitions(-DCPU_FEATURES=true)
         set(BUILD_PIC "ON" CACHE STRING "Hack to enforce fPIC mode" FORCE)
         configure_file(./CMakeLists.txt.cpu_features.in cpu_features-download/CMakeLists.txt)
@@ -169,7 +159,7 @@ endif()


 if (${HELPERS_cudnn})
-    if (NOT CUDA_BLAS)
+    if (NOT SD_CUDA)
         message(FATAL_ERROR "Can't build cuDNN on non-CUDA platform")
     endif()

@@ -231,12 +221,12 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)

 if (NOT DEFINED ENV{CLION_IDE})
     message("NOT CLION")
-    include_directories(blas/ include/ include/helpers include/loops include/graph include/execution include/ops include/types include/array include/cnpy include/exceptions)
+    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
     add_subdirectory(blas)
-    if(BUILD_TESTS)
+    if(SD_BUILD_TESTS)
         # tests are always compiled with all ops included
-        set(LIBND4J_ALL_OPS true)
-        set(LIBND4J_BUILD_MINIFIER true)
+        set(SD_ALL_OPS true)
+        set(SD_BUILD_MINIFIER true)
         add_subdirectory(tests_cpu)
     endif()
 endif ()
@@ -246,7 +236,7 @@ if ($ENV{CLION_IDE})
 endif ()

 if (MSVC_DEV)
-    set(LIBND4J_BUILD_MINIFIER false)
+    set(SD_BUILD_MINIFIER false)
 endif ()

 set (CMAKE_INSTALL_PREFIX $ENV{ND4J_HOME}/nd4j-native-parent/nd4j-native/src/main/resources)
@@ -5,7 +5,7 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
     GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-    GIT_TAG v1.2
+    GIT_TAG v1.2.1
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
     CONFIGURE_COMMAND ""
@@ -9,7 +9,7 @@
       ],
       "buildRoot": "${env.USERPROFILE}\\CMakeBuilds\\${workspaceHash}\\build\\${name}",
       "installRoot": "${env.USERPROFILE}\\CMakeBuilds\\${workspaceHash}\\install\\${name}",
-      "cmakeCommandArgs": " -DCUDA_BLAS=true -DLIBND4J_NAME=nd4jcuda -DMSVC_DEV=true -DCOMPUTE=61 -DBUILD_TESTS=true",
+      "cmakeCommandArgs": " -DSD_CUDA=true -DLIBND4J_NAME=nd4jcuda -DMSVC_DEV=true -DCOMPUTE=61 -DBUILD_TESTS=true",
       "buildCommandArgs": "-v",
       "ctestCommandArgs": ""
     },
@@ -20,7 +20,7 @@
       "buildRoot": "${projectDir}\\out\\build\\${name}",
       "installRoot": "${projectDir}\\out\\install\\${name}",
       "cmakeExecutable": "/usr/bin/cmake",
-      "cmakeCommandArgs": "-DLIBND4J_ALL_OPS=true -DCMAKE_BUILD_TYPE=Debug -DCPU_BLAS=true -DLIBND4J_NAME=nd4jcpu -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug -DOPENBLAS_PATH=/usr/lib/openblas-base/ -DEXTENSION=avx2 ",
+      "cmakeCommandArgs": "-DSD_ALL_OPS=true -DCMAKE_BUILD_TYPE=Debug -DSD_CPU=true -DLIBND4J_NAME=nd4jcpu -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug -DOPENBLAS_PATH=/usr/lib/openblas-base/ -DEXTENSION=avx2 ",
       "buildCommandArgs": "-j 4",
       "ctestCommandArgs": "",
       "inheritEnvironments": [ "linux_x64" ],
@@ -29,19 +29,24 @@ if(APPLE)
     link_directories(/lib)
 endif()

-if (APPLE_BUILD)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
+if (SD_APPLE_BUILD)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_APPLE_BUILD=true -mmacosx-version-min=10.10")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DSD_APPLE_BUILD=true -mmacosx-version-min=10.10")
 endif()

-if (ANDROID_BUILD)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DANDROID_BUILD=true")
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DANDROID_BUILD=true")
+if (SD_ARM_BUILD)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_ARM_BUILD=true")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DSD_ARM_BUILD=true")
 endif()

-if (IOS_BUILD)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIOS_BUILD=true")
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DIOS_BUILD=true")
+if (SD_ANDROID_BUILD)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_ANDROID_BUILD=true")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DSD_ANDROID_BUILD=true")
+endif()
+
+if (SD_IOS_BUILD)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_IOS_BUILD=true")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DSD_IOS_BUILD=true")
 endif()

 if(WIN32)
@@ -63,33 +68,33 @@ if(WIN32)
     SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "")
 endif()

-if ("${LIBND4J_ALL_OPS}")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true")
+if ("${SD_ALL_OPS}")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_ALL_OPS=true")
 else()
-    message("_OPS: ${LIBND4J_OPS_LIST}")
-    foreach(OP "${LIBND4J_OPS_LIST}")
+    message("_OPS: ${SD_OPS_LIST}")
+    foreach(OP "${SD_OPS_LIST}")
         message(STATUS "${OP}")
     endforeach()
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LIBND4J_OPS_LIST}")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SD_OPS_LIST}")
 endif()

-IF(${ARCH} MATCHES "arm*")
-    set(ARCH_TUNE "-march=${ARCH}")
-ELSEIF(${ARCH} MATCHES "power*")
-    set(ARCH_TUNE "-mcpu=${ARCH} -mtune=${ARCH} -D__POWER")
-ELSEIF(${EXTENSION} MATCHES "avx2")
+IF(${SD_ARCH} MATCHES "arm*")
+    set(ARCH_TUNE "-march=${SD_ARCH}")
+ELSEIF(${SD_ARCH} MATCHES "power*")
+    set(ARCH_TUNE "-mcpu=${SD_ARCH} -mtune=${SD_ARCH} -D__POWER")
+ELSEIF(${SD_EXTENSION} MATCHES "avx2")
     message("Building AVX2 binary...")
     set(ARCH_TUNE "-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1 -DSD_F16C=true -DF_AVX2=true")
 ELSE()
-    if ("${ARCH}" STREQUAL "x86-64")
+    if ("${SD_ARCH}" STREQUAL "x86-64")
         message("Building x86_64 binary...")
         set(ARCH_TYPE "generic")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DF_X64=true")
     else()
-        set(ARCH_TYPE "${ARCH}")
+        set(ARCH_TYPE "${SD_ARCH}")
     endif()

-    IF(${EXTENSION} MATCHES "avx512")
+    IF(${SD_EXTENSION} MATCHES "avx512")
         message("Building AVX512 binary...")
         # we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
@@ -97,11 +102,11 @@ ELSE()

     if (NOT WIN32)
         # we don't want this definition for msvc
-        set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
+        set(ARCH_TUNE "-march=${SD_ARCH} -mtune=${ARCH_TYPE}")
     endif()
 ENDIF()

-if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND X86_BUILD)
+if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD)
     # apple clang but not ios-arm
     SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
 elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
@@ -124,10 +129,10 @@ IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
     include_directories("/usr/include")
     include_directories("/usr/local/include")
 ENDIF(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
-if(!CUDA_BLAS)
-    if(!CPU_BLAS)
-        set(CUDA_BLAS FALSE)
-        set(CPU_BLAS TRUE)
+if(!SD_CUDA)
+    if(!SD_CPU)
+        set(SD_CUDA FALSE)
+        set(SD_CPU TRUE)
     endif()
 endif()

@@ -136,7 +141,7 @@ if (HAVE_MKLDNN)
     file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
 endif()

-if(CUDA_BLAS)
+if(SD_CUDA)
     message("Build cublas")
     find_package(CUDA)
     add_definitions(-D__CUDABLAS__=true)
@@ -149,7 +154,7 @@ if(CUDA_BLAS)
         include_directories(${CUDA_INCLUDE_DIRS})
         message("CUDA found!")

-        if ("${EXPERIMENTAL}" STREQUAL "yes")
+        if ("${SD_EXPERIMENTAL}" STREQUAL "yes")
             message("Experimental mode ENABLED")
             set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -D__ND4J_EXPERIMENTAL__=true")
             set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__ND4J_EXPERIMENTAL__=true")
@@ -213,6 +218,7 @@ if(CUDA_BLAS)
         file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
         file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
         file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h)
+        file(GLOB_RECURSE LEGACY_SOURCES false ../include/legacy/impl/*.cpp ../include/legacy/*.cu ../include/legacy/*.h)
         file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)

         if (HAVE_CUDNN)
@@ -220,43 +226,41 @@ if(CUDA_BLAS)
             file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
         endif()

-        add_library(nd4jobj OBJECT cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
+        add_library(nd4jobj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
                 ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
-                ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
-                cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
-                Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
+                ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
                 ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES})

-        add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
+        add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)

         if (WIN32)
             message("MSVC runtime for library: ${MSVC_RT_LIB}")
         endif()

         # static library is built only if we're going to build tests, skip otherwise
-        if (BUILD_TESTS)
-            add_library(${LIBND4J_NAME}static STATIC $<TARGET_OBJECTS:nd4jobj>)
-            set_property(TARGET ${LIBND4J_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
-            install(TARGETS ${LIBND4J_NAME}static DESTINATION .)
+        if (SD_BUILD_TESTS OR SD_STATIC_LIB)
+            add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:nd4jobj>)
+            set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+            install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .)
         endif()

         # on windows we want to make sure we use MT or MD, but since we use it in one lib, we must use it everywhere to avoid conflicts
         set_property(TARGET nd4jobj PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
-        set_property(TARGET ${LIBND4J_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+        set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")

         if(WIN32)
             message("CUDA on Windows: enabling /EHsc")
             SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
         endif()

-        target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
+        target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
         set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)

-        install(TARGETS ${LIBND4J_NAME} DESTINATION .)
+        install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
     endif(CUDA_FOUND)
-elseif(CPU_BLAS)
+elseif(SD_CPU)

-    if ("${EXPERIMENTAL}" STREQUAL "yes")
+    if ("${SD_EXPERIMENTAL}" STREQUAL "yes")
         message("Experimental mode ENABLED")
         set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__ND4J_EXPERIMENTAL__=true")
         set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__ND4J_EXPERIMENTAL__=true")
@@ -274,15 +278,16 @@ elseif(CPU_BLAS)
     file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
     file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
     file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
+    file(GLOB_RECURSE LEGACY_SOURCES false ../include/legacy/impl/*.cpp ../include/legacy/cpu/*.cpp ../include/legacy/*.h)
     file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)

-    if (X86_BUILD)
+    if (SD_X86_BUILD)
         # we disable platform optimizations for certains files for linux/macos
         set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
         set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
     endif()

-    if(CHECK_VECTORIZATION)
+    if(SD_CHECK_VECTORIZATION)
         set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
         if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")

@@ -310,33 +315,31 @@ elseif(CPU_BLAS)

     message("CPU BLAS")
     add_definitions(-D__CPUBLAS__=true)
-    add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp
-            cpu/NativeOpExecutioner.cpp cpu/NDArray.cpp cpu/NDArrayFactory.cpp
-            ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
-            Environment.cpp Environment.h ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
+    add_library(nd4jobj OBJECT ${LEGACY_SOURCES}
+            ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||||
${OPS_SOURCES} ${PERF_SOURCES})
|
${OPS_SOURCES} ${PERF_SOURCES})
|
||||||
if(IOS)
|
if(IOS)
|
||||||
add_library(${LIBND4J_NAME} STATIC $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:nd4jobj>)
|
||||||
else()
|
else()
|
||||||
# static library is built only if we're going to build tests, skip otherwise
|
# static library is built only if we're going to build tests, skip otherwise
|
||||||
if (BUILD_TESTS)
|
if (SD_BUILD_TESTS OR SD_STATIC_LIB)
|
||||||
add_library(${LIBND4J_NAME}static STATIC $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:nd4jobj>)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS
|
# we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS
|
||||||
if (NOT BLAS_LIBRARIES)
|
if (NOT BLAS_LIBRARIES)
|
||||||
set(BLAS_LIBRARIES "")
|
set(BLAS_LIBRARIES "")
|
||||||
endif()
|
endif()
|
||||||
target_link_libraries(${LIBND4J_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||||
|
|
||||||
if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
|
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
|
||||||
message(STATUS "Building minifier...")
|
message(STATUS "Building minifier...")
|
||||||
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
|
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
|
||||||
target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
target_link_libraries(minifier ${SD_LIBRARY_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
||||||
|
@ -357,6 +360,6 @@ elseif(CPU_BLAS)
|
||||||
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
install(TARGETS ${LIBND4J_NAME} DESTINATION .)
|
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cpu)
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cpu)
|
||||||
endif()
|
endif()
|
||||||
|
|
@@ -1,191 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- * Copyright (c) 2019-2020 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// Created by raver119 on 2018-09-16.
-// @author Oleg Semeniv <oleg.semeniv@gmail.com>
-//
-
-#ifndef DEV_TESTS_NDARRAYFACTORY_H
-#define DEV_TESTS_NDARRAYFACTORY_H
-
-#include <vector>
-#include <initializer_list>
-#include <NDArray.h>
-//#include <memory/Workspace.h>
-#include <execution/LaunchContext.h>
-#include <string>
-
-
-namespace nd4j {
-    class ND4J_EXPORT NDArrayFactory {
-    private:
-        template <typename T>
-        static void memcpyFromVector(void *ptr, const std::vector<T> &vector);
-    public:
-        template <typename T>
-        static NDArray* empty_(nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        static NDArray* empty_(nd4j::DataType dataType, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray empty(nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        static NDArray empty(nd4j::DataType dataType, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray* valueOf(const std::initializer_list<Nd4jLong>& shape, T value, char order = 'c', nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray* valueOf(const std::vector<Nd4jLong>& shape, T value, char order = 'c', nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        static NDArray* valueOf(const std::vector<Nd4jLong>& shape, const NDArray& value, char order = 'c', nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray* linspace(T from, T to, Nd4jLong numElements);
-
-
-        template <typename T>
-        static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray create(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        template <typename T>
-        static NDArray create(DataType type, const T scalar, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-
-        template <typename T>
-        static NDArray* vector(Nd4jLong length, T startingValue = (T) 0, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray* create_(char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        static NDArray* create_( char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray* create_(char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray create(char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray create(char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray create(const std::vector<T> &values, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-#ifndef __JAVACPP_HACK__
-        // this method only available out of javacpp
-        /**
-         * This constructor creates vector of T
-         *
-         * @param values
-         */
-
-        template <typename T>
-        static NDArray create(char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray create(T* buffer, char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        template <typename T>
-        static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        /**
-         * This method creates NDArray from .npy file
-         * @param fileName
-         * @return
-         */
-        static NDArray fromNpyFile(const char *fileName);
-
-        /**
-         * This factory create array from utf8 string
-         * @return NDArray default dataType UTF8
-         */
-        static NDArray string(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* string_(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* string_(const std::string &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray string(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-
-        /**
-         * This factory create array from utf16 string
-         * @return NDArray default dataType UTF16
-         */
-        static NDArray string(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-
-        /**
-         * This factory create array from utf32 string
-         * @return NDArray default dataType UTF32
-         */
-        static NDArray string(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-
-        /**
-         * This factory create array from vector of utf8 strings
-         * @return NDArray default dataType UTF8
-         */
-        static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-        /**
-         * This factory create array from vector of utf16 strings
-         * @return NDArray default dataType UTF16
-         */
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-
-        /**
-         * This factory create array from vector of utf32 strings
-         * @return NDArray default dataType UTF32
-         */
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-        static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
-
-
-        static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
-
-#endif
-    };
-}
-
-#endif //DEV_TESTS_NDARRAYFACTORY_H
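The file removed above is the pre-refactor NDArrayFactory header; the same factory API continues to exist after this change, just under the new sd namespace rather than nd4j. As a reference point for readers of the diff, here is a minimal C++ usage sketch written against the declarations shown. It assumes a build that links libnd4j with this header reachable as <NDArrayFactory.h>; the shapes, values, and the printIndexedBuffer() call are illustrative, not part of this diff.

#include <NDArray.h>
#include <NDArrayFactory.h>

int main() {
    // 2x2 float array in 'c' (row-major) order, via
    // create(char order, const std::vector<Nd4jLong>&, const std::vector<T>&, ...)
    auto a = nd4j::NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});

    // scalar and utf8-string factories declared in the same header
    auto s = nd4j::NDArrayFactory::create<double>(3.14);
    auto u = nd4j::NDArrayFactory::string("hello");

    a.printIndexedBuffer("a");   // NDArray's debug printer (assumed available)
    return 0;
}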
@@ -1,148 +0,0 @@
-################################################################################
-# Copyright (c) 2015-2018 Skymind, Inc.
-#
-# This program and the accompanying materials are made available under the
-# terms of the Apache License, Version 2.0 which is available at
-# https://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# SPDX-License-Identifier: Apache-2.0
-################################################################################
-
-#ifndef NDARRAY_MACRO
-#define NDARRAY_MACRO
-
-#include <op_boilerplate.h>
-
-//NDArray<T> *other, T *extraParams
-BUILD_CALL_1(template void NDArray<float>::template applyPairwiseTransform, float, (NDArray<float>* other, float* extraParams), PAIRWISE_TRANSFORM_OPS)
-BUILD_CALL_1(template void NDArray<float16>::applyPairwiseTransform, float16, (NDArray<float16>* other, float16* extraParams), PAIRWISE_TRANSFORM_OPS)
-BUILD_CALL_1(template void NDArray<double>::applyPairwiseTransform, double, (NDArray<double>* other, double* extraParams), PAIRWISE_TRANSFORM_OPS)
-
-// NDArray<T> *other, NDArray<T> *target, T *extraParams
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyPairwiseTransform, float, (NDArray<float>* other, NDArray<float>* target, float* extraParams), PAIRWISE_TRANSFORM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyPairwiseTransform, float16, (NDArray<float16>* other, NDArray<float16>* target, float16* extraParams), PAIRWISE_TRANSFORM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyPairwiseTransform, double, (NDArray<double>* other, NDArray<double>* target, double* extraParams), PAIRWISE_TRANSFORM_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyScalar, float16, (NDArray<float16>& scalar, NDArray<float16>* target, float16 *extraParams) const, SCALAR_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyScalar, float16, (float16 scalar, NDArray<float16>* target, float16 *extraParams) const, SCALAR_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyScalar, float, (NDArray<float>& scalar, NDArray<float>* target, float *extraParams) const, SCALAR_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyScalar, float, (float scalar, NDArray<float>* target, float *extraParams) const, SCALAR_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyScalar, double, (NDArray<double>& scalar, NDArray<double>* target, double *extraParams) const, SCALAR_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyScalar, double, (double scalar, NDArray<double>* target, double *extraParams) const, SCALAR_OPS)
-
-
-
-BUILD_CALL_1(template float16 nd4j::NDArray<float16>::reduceNumber, float16, (float16 *extraParams) const, REDUCE_OPS)
-BUILD_CALL_1(template float nd4j::NDArray<float>::reduceNumber, float, (float *extraParams) const, REDUCE_OPS)
-BUILD_CALL_1(template double nd4j::NDArray<double>::reduceNumber, double, (double *extraParams) const, REDUCE_OPS)
-
-BUILD_CALL_1(template Nd4jLong nd4j::NDArray<float16>::indexReduceNumber, float16, (float16 *extraParams), INDEX_REDUCE_OPS)
-BUILD_CALL_1(template Nd4jLong nd4j::NDArray<float>::indexReduceNumber, float, (float *extraParams), INDEX_REDUCE_OPS)
-BUILD_CALL_1(template Nd4jLong nd4j::NDArray<double>::indexReduceNumber, double, (double *extraParams), INDEX_REDUCE_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyBroadcast, float16, (std::initializer_list<int> list, const nd4j::NDArray<float16>* a, nd4j::NDArray<float16>* b, float16* c), BROADCAST_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyBroadcast, float, (std::initializer_list<int> list, const nd4j::NDArray<float>* a, nd4j::NDArray<float>* b, float* c), BROADCAST_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyBroadcast, double, (std::initializer_list<int> list, const nd4j::NDArray<double>* a, nd4j::NDArray<double>* b, double* c), BROADCAST_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyTrueBroadcast, float16,(const nd4j::NDArray<float16>* a, nd4j::NDArray<float16>* target, const bool checkTargetShape, float16* c) const, BROADCAST_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyTrueBroadcast, float, (const nd4j::NDArray<float>* a, nd4j::NDArray<float>* target, const bool checkTargetShape, float* c) const, BROADCAST_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyTrueBroadcast, double, (const nd4j::NDArray<double>* a, nd4j::NDArray<double>* target, const bool checkTargetShape, double* c) const, BROADCAST_OPS)
-
-BUILD_CALL_1(template nd4j::NDArray<float16>* nd4j::NDArray<float16>::applyTrueBroadcast, float16, (const nd4j::NDArray<float16>* a, float16* c) const, BROADCAST_OPS)
-BUILD_CALL_1(template nd4j::NDArray<float>* nd4j::NDArray<float>::applyTrueBroadcast, float, (const nd4j::NDArray<float>* a, float* c) const, BROADCAST_OPS)
-BUILD_CALL_1(template nd4j::NDArray<double>* nd4j::NDArray<double>::applyTrueBroadcast, double, (const nd4j::NDArray<double>* a, double* c) const, BROADCAST_OPS)
-
-BUILD_CALL_1(template nd4j::NDArray<float16> nd4j::NDArray<float16>::applyTrueBroadcast, float16, (const nd4j::NDArray<float16>& a, float16* c) const, BROADCAST_OPS)
-BUILD_CALL_1(template nd4j::NDArray<float> nd4j::NDArray<float>::applyTrueBroadcast, float, (const nd4j::NDArray<float>& a, float* c) const, BROADCAST_OPS)
-BUILD_CALL_1(template nd4j::NDArray<double> nd4j::NDArray<double>::applyTrueBroadcast, double, (const nd4j::NDArray<double>& a, double* c) const, BROADCAST_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyTransform, float16, (NDArray<float16>* target, float16* extraParams), TRANSFORM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyTransform, float, (NDArray<float>* target, float* extraParams), TRANSFORM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyTransform, double, (NDArray<double>* target, double* extraParams), TRANSFORM_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyTransform, float16, (float16* extraParams), TRANSFORM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyTransform, float, (float* extraParams), TRANSFORM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyTransform, double, (double* extraParams), TRANSFORM_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float16>::applyRandom, float16, (nd4j::random::RandomBuffer *buffer, NDArray<float16>* y, NDArray<float16>* z, float16* extraParams), RANDOM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float>::applyRandom, float, (nd4j::random::RandomBuffer *buffer, NDArray<float>* y, NDArray<float>* z, float* extraParams), RANDOM_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::applyRandom, double, (nd4j::random::RandomBuffer *buffer, NDArray<double>* y, NDArray<double>* z, double* extraParams), RANDOM_OPS)
-
-BUILD_CALL_1(template NDArray<float16> nd4j::NDArray<float16>::transform, float16, (float16* extraParams) const, TRANSFORM_OPS)
-BUILD_CALL_1(template NDArray<float> nd4j::NDArray<float>::transform, float, (float* extraParams) const, TRANSFORM_OPS)
-BUILD_CALL_1(template NDArray<double> nd4j::NDArray<double>::transform, double, (double* extraParams) const, TRANSFORM_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template reduceAlongDimension, float, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template reduceAlongDimension, float16, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template reduceAlongDimension, double, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-
-BUILD_CALL_1(template NDArray<float> nd4j::NDArray<float>::template reduceAlongDims, float, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-BUILD_CALL_1(template NDArray<float16> nd4j::NDArray<float16>::template reduceAlongDims, float16, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-BUILD_CALL_1(template NDArray<double> nd4j::NDArray<double>::template reduceAlongDims, double, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template reduceAlongDimension, float, (const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template reduceAlongDimension, float16, (const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template reduceAlongDimension, double, (const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float>::template reduceAlongDimension, float, (NDArray<float>* target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, float * extras) const, REDUCE_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float16>::template reduceAlongDimension, float16, (NDArray<float16>* target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, float16 * extras) const, REDUCE_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::template reduceAlongDimension, double, (NDArray<double>* target, const std::vector<int>& dimension, const bool keepDims, const bool supportOldShapes, double * extras) const, REDUCE_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template varianceAlongDimension, float, (const bool biasCorrected, const std::initializer_list<int>& dimensions) const, SUMMARY_STATS_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template varianceAlongDimension, float16, (const bool biasCorrected, const std::initializer_list<int>& dimensions) const, SUMMARY_STATS_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template varianceAlongDimension, double, (const bool biasCorrected, const std::initializer_list<int>& dimensions) const, SUMMARY_STATS_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float>::template varianceAlongDimension, float, (const NDArray<float> *target, const bool biasCorrected, const std::initializer_list<int>& dimensions), SUMMARY_STATS_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float16>::template varianceAlongDimension, float16, (const NDArray<float16> *target,const bool biasCorrected, const std::initializer_list<int>& dimensions), SUMMARY_STATS_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::template varianceAlongDimension, double, (const NDArray<double> *target, const bool biasCorrected, const std::initializer_list<int>& dimensions), SUMMARY_STATS_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float>::template varianceAlongDimension, float, (const NDArray<float> *target, const bool biasCorrected, const std::vector<int>& dimensions), SUMMARY_STATS_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float16>::template varianceAlongDimension, float16, (const NDArray<float16> *target,const bool biasCorrected, const std::vector<int>& dimensions), SUMMARY_STATS_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::template varianceAlongDimension, double, (const NDArray<double> *target, const bool biasCorrected, const std::vector<int>& dimensions), SUMMARY_STATS_OPS)
-
-BUILD_CALL_1(template float nd4j::NDArray<float>::template varianceNumber, float, (bool biasCorrected), SUMMARY_STATS_OPS)
-BUILD_CALL_1(template float16 nd4j::NDArray<float16>::template varianceNumber, float16, (bool biasCorrected), SUMMARY_STATS_OPS)
-BUILD_CALL_1(template double nd4j::NDArray<double>::template varianceNumber, double, (bool biasCorrected), SUMMARY_STATS_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template applyReduce3, float, (const NDArray<float>* other, const float* extraParams) const, REDUCE3_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template applyReduce3, float16, (const NDArray<float16>* other, const float16* extraParams) const, REDUCE3_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template applyReduce3, double, (const NDArray<double>* other, const double* extraParams) const, REDUCE3_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template applyReduce3, float, (const NDArray<float>* other, const std::vector<int> &dims, const float* extraParams) const, REDUCE3_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template applyReduce3, float16, (const NDArray<float16>* other, const std::vector<int> &dims, const float16* extraParams) const, REDUCE3_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template applyReduce3, double, (const NDArray<double>* other, const std::vector<int> &dims, const double* extraParams) const, REDUCE3_OPS)
-
-BUILD_CALL_1(template void nd4j::NDArray<float>::template applyIndexReduce, float, (const NDArray<float>* target, const std::vector<int> & alpha, const float* beta) const, INDEX_REDUCE_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<float16>::template applyIndexReduce, float16, (const NDArray<float16>* target, const std::vector<int> & alpha, const float16* beta) const, INDEX_REDUCE_OPS)
-BUILD_CALL_1(template void nd4j::NDArray<double>::template applyIndexReduce, double, (const NDArray<double>* target, const std::vector<int> & alpha, const double* beta) const, INDEX_REDUCE_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template applyIndexReduce, float, (const std::vector<int> & alpha, const float* beta) const, INDEX_REDUCE_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template applyIndexReduce, float16, (const std::vector<int> & alpha, const float16* beta) const, INDEX_REDUCE_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template applyIndexReduce, double, (const std::vector<int> & alpha, const double* beta) const, INDEX_REDUCE_OPS)
-
-BUILD_CALL_1(template NDArray<float> *nd4j::NDArray<float>::template applyAllReduce3, float, (const nd4j::NDArray<float>* alpha, const std::vector<int> & beta, float const* gamma) const, REDUCE3_OPS)
-BUILD_CALL_1(template NDArray<float16> *nd4j::NDArray<float16>::template applyAllReduce3, float16, (const nd4j::NDArray<float16>* alpha, const std::vector<int> & beta, float16 const* gamma) const, REDUCE3_OPS)
-BUILD_CALL_1(template NDArray<double> *nd4j::NDArray<double>::template applyAllReduce3, double, (const nd4j::NDArray<double>* alpha, const std::vector<int> & beta, double const* gamma) const, REDUCE3_OPS)
-
-template NDArray<float> mmul(const NDArray<float>& left, const NDArray<float>& right);
-template NDArray<float16> mmul(const NDArray<float16>& left, const NDArray<float16>& right);
-template NDArray<double> mmul(const NDArray<double>& left, const NDArray<double>& right);
-
-// template NDArray<float> operator-(const float, const NDArray<float>&);
-// template NDArray<float16> operator-(const float16, const NDArray<float16>&);
-// template NDArray<double> operator-(const double, const NDArray<double>&);
-
-// template NDArray<float> operator+(const float, const NDArray<float>&);
-// template NDArray<float16> operator+(const float16, const NDArray<float16>&);
-// template NDArray<double> operator+(const double, const NDArray<double>&);
-
-
-#endif
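The deleted .macro file above existed to pin template code into a single translation unit: each BUILD_CALL_1 line presumably expands its first argument into one explicit instantiation per operation in the trailing ops list (an assumption here; the macro itself is defined in op_boilerplate.h, outside this diff). The underlying C++ mechanism is ordinary explicit template instantiation; a self-contained sketch of that pattern, with hypothetical names (Array, scale):

#include <vector>

// A class template whose member is defined out of line, as NDArray<T> was.
template <typename T>
struct Array {
    std::vector<T> data;
    void scale(T factor);
};

template <typename T>
void Array<T>::scale(T factor) {
    for (auto &v : data)
        v *= factor;
}

// Explicit instantiations: the member definition is compiled into this one
// translation unit for each supported element type, so other files only need
// the declaration. The macro file did exactly this, per (type, op) pair.
template void Array<float>::scale(float);
template void Array<double>::scale(double);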
@@ -173,7 +173,7 @@ fi
case "$OS" in
    linux-armhf)
        export RPI_BIN=$RPI_HOME/tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin/arm-linux-gnueabihf
-       export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake -DSD_ARM_BUILD=true"
        if [ -z "$ARCH" ]; then
            ARCH="armv7-r"
        fi

@@ -183,6 +183,7 @@ case "$OS" in
        if [ -z "$ARCH" ]; then
            ARCH="armv8-a"
        fi
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DSD_ARM_BUILD=true"
        ;;

    android-arm)
@@ -193,7 +194,7 @@ case "$OS" in
        export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/"
        export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang"
        export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-arm/"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm.cmake -DANDROID_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm.cmake -DSD_ANDROID_BUILD=true"
        ;;

    android-arm64)
@@ -204,7 +205,7 @@ case "$OS" in
        export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/"
        export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang"
        export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-arm64/"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm64.cmake -DANDROID_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm64.cmake -DSD_ANDROID_BUILD=true"
        ;;

    android-x86)
@@ -215,7 +216,7 @@ case "$OS" in
        export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/"
        export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang"
        export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-x86/"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86.cmake -DANDROID_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86.cmake -DSD_ANDROID_BUILD=true"
        ;;

    android-x86_64)
@@ -226,7 +227,7 @@ case "$OS" in
        export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/"
        export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang"
        export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-x86_64/"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86_64.cmake -DANDROID_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86_64.cmake -DSD_ANDROID_BUILD=true"
        ;;

    ios-x86_64)
@@ -239,7 +240,7 @@ case "$OS" in
        fi
        XCODE_PATH="$(xcode-select --print-path)"
        export IOS_SDK="$XCODE_PATH/Platforms/iPhoneSimulator.platform/Developer/SDKs/iPhoneSimulator$IOS_VERSION.sdk"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86_64.cmake --debug-trycompile -DIOS_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86_64.cmake --debug-trycompile -DSD_IOS_BUILD=true"
        ;;

    ios-x86)
@@ -252,7 +253,7 @@ case "$OS" in
        fi
        XCODE_PATH="$(xcode-select --print-path)"
        export IOS_SDK="$XCODE_PATH/Platforms/iPhoneSimulator.platform/Developer/SDKs/iPhoneSimulator$IOS_VERSION.sdk"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86.cmake --debug-trycompile -DIOS_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86.cmake --debug-trycompile -DSD_IOS_BUILD=true"
        ;;

    ios-arm64)
@@ -265,7 +266,7 @@ case "$OS" in
        fi
        XCODE_PATH="$(xcode-select --print-path)"
        export IOS_SDK="$XCODE_PATH/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS$IOS_VERSION.sdk"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm64.cmake --debug-trycompile -DIOS_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm64.cmake --debug-trycompile -DSD_IOS_BUILD=true"
        ;;

    ios-arm)
@@ -278,7 +279,7 @@ case "$OS" in
        fi
        XCODE_PATH="$(xcode-select --print-path)"
        export IOS_SDK="$XCODE_PATH/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS$IOS_VERSION.sdk"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm.cmake --debug-trycompile -DIOS_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm.cmake --debug-trycompile -DSD_IOS_BUILD=true"
        ;;

    ios-armv7)
@@ -288,7 +289,7 @@ case "$OS" in
        LIBTYPE="static"
        ARCH="armv7"
        export IOS_SDK="/Applications/Xcode.app/Contents/Developer/Platforms/${iPhoneOS}.platform/Developer/SDKs/${iPhoneOS}${IOS_VERSION}.sdk"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-armv7.cmake --debug-trycompile -DIOS_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-armv7.cmake --debug-trycompile -DSD_IOS_BUILD=true"
        ;;

    linux*)
@@ -298,7 +299,7 @@ case "$OS" in
        export CC=clang
        export CXX=clang++
        PARALLEL="true"
-       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true"
+       export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DSD_APPLE_BUILD=true"
        ;;

    windows*)
@@ -375,7 +376,7 @@ fi
OPERATIONS_ARG=

if [ -z "$OPERATIONS" ]; then
-   OPERATIONS_ARG="-DLIBND4J_ALL_OPS=true"
+   OPERATIONS_ARG="-DSD_ALL_OPS=true"
else
    OPERATIONS_ARG=$OPERATIONS
fi

@@ -385,9 +386,9 @@ if [ -z "$EXPERIMENTAL" ]; then
fi

if [ "$CHIP" == "cpu" ]; then
-   BLAS_ARG="-DCPU_BLAS=true -DBLAS=TRUE"
+   BLAS_ARG="-DSD_CPU=true -DBLAS=TRUE"
else
-   BLAS_ARG="-DCUDA_BLAS=true -DBLAS=TRUE"
+   BLAS_ARG="-DSD_CUDA=true -DBLAS=TRUE"
fi

if [ -z "$NAME" ]; then
@@ -399,9 +400,9 @@ if [ -z "$NAME" ]; then
fi

if [ "$LIBTYPE" == "dynamic" ]; then
-   SHARED_LIBS_ARG="-DBUILD_SHARED_LIBS=OFF"
+   SHARED_LIBS_ARG="-DSD_SHARED_LIB=OFF"
else
-   SHARED_LIBS_ARG="-DBUILD_SHARED_LIBS=ON"
+   SHARED_LIBS_ARG="-DSD_SHARED_LIB=ON"
fi

if [ "$BUILD" == "release" ]; then
@@ -428,24 +429,24 @@ if [ "$PACKAGING" == "msi" ]; then
fi

EXPERIMENTAL_ARG="";
-MINIFIER_ARG="-DLIBND4J_BUILD_MINIFIER=false"
+MINIFIER_ARG="-DSD_BUILD_MINIFIER=false"
-TESTS_ARG="-DBUILD_TESTS=OFF"
+TESTS_ARG="-DSD_BUILD_TESTS=OFF"
-NAME_ARG="-DLIBND4J_NAME=$NAME"
+NAME_ARG="-DSD_LIBRARY_NAME=$NAME"

if [ "$EXPERIMENTAL" == "yes" ]; then
-   EXPERIMENTAL_ARG="-DEXPERIMENTAL=yes"
+   EXPERIMENTAL_ARG="-DSD_EXPERIMENTAL=yes"
fi

if [ "$MINIFIER" == "true" ]; then
-   MINIFIER_ARG="-DLIBND4J_BUILD_MINIFIER=true"
+   MINIFIER_ARG="-DSD_BUILD_MINIFIER=true"
fi

if [ "$TESTS" == "true" ]; then
-   MINIFIER_ARG="-DLIBND4J_BUILD_MINIFIER=true"
+   MINIFIER_ARG="-DSD_BUILD_MINIFIER=true"
-   TESTS_ARG="-DBUILD_TESTS=ON"
+   TESTS_ARG="-DSD_BUILD_TESTS=ON"
fi

-ARCH_ARG="-DARCH=$ARCH -DEXTENSION=$CHIP_EXTENSION"
+ARCH_ARG="-DSD_ARCH=$ARCH -DSD_EXTENSION=$CHIP_EXTENSION"

CUDA_COMPUTE="-DCOMPUTE=$COMPUTE"

@@ -536,7 +537,7 @@ echo CHECK_VECTORIZATION = "$CHECK_VECTORIZATION"
echo HELPERS = "$HELPERS"
mkbuilddir
pwd
-eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DCHECK_VECTORIZATION="${CHECK_VECTORIZATION}" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
+eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DSD_CHECK_VECTORIZATION="${CHECK_VECTORIZATION}" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..

if [ "$PARALLEL" == "true" ]; then
    MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
@ -21,9 +21,9 @@
|
||||||
#ifndef ND4J_ARRAY_OPTIONS_H
|
#ifndef ND4J_ARRAY_OPTIONS_H
|
||||||
#define ND4J_ARRAY_OPTIONS_H
|
#define ND4J_ARRAY_OPTIONS_H
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#include <pointercast.h>
|
#include <system/pointercast.h>
|
||||||
#include <dll.h>
|
#include <system/dll.h>
|
||||||
#include <array/DataType.h>
|
#include <array/DataType.h>
|
||||||
#include <array/ArrayType.h>
|
#include <array/ArrayType.h>
|
||||||
#include <array/SpaceType.h>
|
#include <array/SpaceType.h>
|
||||||
|
@ -87,7 +87,7 @@
|
||||||
#define ARRAY_UNSIGNED 8388608
|
#define ARRAY_UNSIGNED 8388608
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace sd {
|
||||||
class ND4J_EXPORT ArrayOptions {
|
class ND4J_EXPORT ArrayOptions {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -104,7 +104,7 @@ namespace nd4j {
|
||||||
static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo);
|
static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo);
|
||||||
static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo);
|
static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
static FORCEINLINE _CUDA_HD nd4j::DataType dataType(const Nd4jLong *shapeInfo);
|
static FORCEINLINE _CUDA_HD sd::DataType dataType(const Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo);
|
static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo);
|
||||||
static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo);
|
static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo);
|
||||||
|
@ -119,7 +119,7 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
     static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo);

-    static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const nd4j::DataType dataType);
+    static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType);

     static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong* to, const Nd4jLong* from);
 };

@ -155,34 +155,34 @@ namespace nd4j {
         return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED);
     }

-    FORCEINLINE _CUDA_HD nd4j::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) {
+    FORCEINLINE _CUDA_HD sd::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) {
         /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED))
-            return nd4j::DataType::QINT8;
+            return sd::DataType::QINT8;
         else */if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT))
-            return nd4j::DataType::FLOAT32;
+            return sd::DataType::FLOAT32;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE))
-            return nd4j::DataType::DOUBLE;
+            return sd::DataType::DOUBLE;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF))
-            return nd4j::DataType::HALF;
+            return sd::DataType::HALF;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF))
-            return nd4j::DataType::BFLOAT16;
+            return sd::DataType::BFLOAT16;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL))
-            return nd4j::DataType ::BOOL;
+            return sd::DataType ::BOOL;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) {
             if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR))
-                return nd4j::DataType ::UINT8;
+                return sd::DataType ::UINT8;
             else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT))
-                return nd4j::DataType ::UINT16;
+                return sd::DataType ::UINT16;
             else if (hasPropertyBitSet(shapeInfo, ARRAY_INT))
-                return nd4j::DataType ::UINT32;
+                return sd::DataType ::UINT32;
             else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
-                return nd4j::DataType ::UINT64;
+                return sd::DataType ::UINT64;
             else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
-                return nd4j::DataType ::UTF8;
+                return sd::DataType ::UTF8;
             else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
-                return nd4j::DataType ::UTF16;
+                return sd::DataType ::UTF16;
             else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
-                return nd4j::DataType ::UTF32;
+                return sd::DataType ::UTF32;
             else {
                 //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
 #ifndef __CUDA_ARCH__

@ -191,19 +191,19 @@ namespace nd4j {
             }
         }
         else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR))
-            return nd4j::DataType::INT8;
+            return sd::DataType::INT8;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT))
-            return nd4j::DataType::INT16;
+            return sd::DataType::INT16;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_INT))
-            return nd4j::DataType::INT32;
+            return sd::DataType::INT32;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
-            return nd4j::DataType::INT64;
+            return sd::DataType::INT64;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
-            return nd4j::DataType::UTF8;
+            return sd::DataType::UTF8;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
-            return nd4j::DataType::UTF16;
+            return sd::DataType::UTF16;
         else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
-            return nd4j::DataType::UTF32;
+            return sd::DataType::UTF32;
         else {
             //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
 #ifndef __CUDA_ARCH__
@ -296,63 +296,63 @@ namespace nd4j {
             unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED);
     }

-    FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const nd4j::DataType dataType) {
+    FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType) {
         resetDataType(shapeInfo);
-        if (dataType == nd4j::DataType::UINT8 ||
+        if (dataType == sd::DataType::UINT8 ||
-            dataType == nd4j::DataType::UINT16 ||
+            dataType == sd::DataType::UINT16 ||
-            dataType == nd4j::DataType::UINT32 ||
+            dataType == sd::DataType::UINT32 ||
-            dataType == nd4j::DataType::UINT64) {
+            dataType == sd::DataType::UINT64) {

             setPropertyBit(shapeInfo, ARRAY_UNSIGNED);
         }

         switch (dataType) {
-            case nd4j::DataType::BOOL:
+            case sd::DataType::BOOL:
                 setPropertyBit(shapeInfo, ARRAY_BOOL);
                 break;
-            case nd4j::DataType::HALF:
+            case sd::DataType::HALF:
                 setPropertyBit(shapeInfo, ARRAY_HALF);
                 break;
-            case nd4j::DataType::BFLOAT16:
+            case sd::DataType::BFLOAT16:
                 setPropertyBit(shapeInfo, ARRAY_BHALF);
                 break;
-            case nd4j::DataType::FLOAT32:
+            case sd::DataType::FLOAT32:
                 setPropertyBit(shapeInfo, ARRAY_FLOAT);
                 break;
-            case nd4j::DataType::DOUBLE:
+            case sd::DataType::DOUBLE:
                 setPropertyBit(shapeInfo, ARRAY_DOUBLE);
                 break;
-            case nd4j::DataType::INT8:
+            case sd::DataType::INT8:
                 setPropertyBit(shapeInfo, ARRAY_CHAR);
                 break;
-            case nd4j::DataType::INT16:
+            case sd::DataType::INT16:
                 setPropertyBit(shapeInfo, ARRAY_SHORT);
                 break;
-            case nd4j::DataType::INT32:
+            case sd::DataType::INT32:
                 setPropertyBit(shapeInfo, ARRAY_INT);
                 break;
-            case nd4j::DataType::INT64:
+            case sd::DataType::INT64:
                 setPropertyBit(shapeInfo, ARRAY_LONG);
                 break;
-            case nd4j::DataType::UINT8:
+            case sd::DataType::UINT8:
                 setPropertyBit(shapeInfo, ARRAY_CHAR);
                 break;
-            case nd4j::DataType::UINT16:
+            case sd::DataType::UINT16:
                 setPropertyBit(shapeInfo, ARRAY_SHORT);
                 break;
-            case nd4j::DataType::UINT32:
+            case sd::DataType::UINT32:
                 setPropertyBit(shapeInfo, ARRAY_INT);
                 break;
-            case nd4j::DataType::UINT64:
+            case sd::DataType::UINT64:
                 setPropertyBit(shapeInfo, ARRAY_LONG);
                 break;
-            case nd4j::DataType::UTF8:
+            case sd::DataType::UTF8:
                 setPropertyBit(shapeInfo, ARRAY_UTF8);
                 break;
-            case nd4j::DataType::UTF16:
+            case sd::DataType::UTF16:
                 setPropertyBit(shapeInfo, ARRAY_UTF16);
                 break;
-            case nd4j::DataType::UTF32:
+            case sd::DataType::UTF32:
                 setPropertyBit(shapeInfo, ARRAY_UTF32);
                 break;
             default:
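The two hunks above are the core of the shape-info data-type encoding: `setDataType` records one width flag per type, with unsigned integral types reusing the signed width flag plus `ARRAY_UNSIGNED`, and `dataType` decodes by checking that discriminator bit before the width flags. A minimal self-contained sketch of the same scheme follows; the flag values here are illustrative only, the real `ARRAY_*` constants live in ArrayOptions.h:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative flag values; the real ARRAY_* constants are defined in
// array/ArrayOptions.h. Only the encoding scheme is reproduced here.
enum : uint64_t {
    ARRAY_FLOAT    = 1ull << 0,
    ARRAY_CHAR     = 1ull << 1,  // 8-bit integral width
    ARRAY_INT      = 1ull << 2,  // 32-bit integral width
    ARRAY_UNSIGNED = 1ull << 3,  // discriminator: width flags mean uintN
};

enum class DataType { INT8, UINT8, INT32, UINT32, FLOAT32, UNKNOWN };

// Encode: unsigned integral types reuse the signed width flag and add
// ARRAY_UNSIGNED, as ArrayOptions::setDataType does above.
uint64_t encode(DataType t) {
    switch (t) {
        case DataType::INT8:    return ARRAY_CHAR;
        case DataType::UINT8:   return ARRAY_CHAR | ARRAY_UNSIGNED;
        case DataType::INT32:   return ARRAY_INT;
        case DataType::UINT32:  return ARRAY_INT | ARRAY_UNSIGNED;
        case DataType::FLOAT32: return ARRAY_FLOAT;
        default:                return 0;
    }
}

// Decode: test the unsigned discriminator before the width flags,
// mirroring the branch order in ArrayOptions::dataType.
DataType decode(uint64_t flags) {
    if (flags & ARRAY_FLOAT) return DataType::FLOAT32;
    if (flags & ARRAY_UNSIGNED) {
        if (flags & ARRAY_CHAR) return DataType::UINT8;
        if (flags & ARRAY_INT)  return DataType::UINT32;
    } else {
        if (flags & ARRAY_CHAR) return DataType::INT8;
        if (flags & ARRAY_INT)  return DataType::INT32;
    }
    return DataType::UNKNOWN;
}

int main() {
    // Round-trip: UINT32 -> flags -> UINT32
    uint64_t f = encode(DataType::UINT32);
    std::printf("round-trip ok: %d\n", decode(f) == DataType::UINT32);
}
```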
@ -21,7 +21,7 @@
 #ifndef ND4J_ARRAY_TYPE_H
 #define ND4J_ARRAY_TYPE_H

-namespace nd4j {
+namespace sd {
     enum ArrayType {
         DENSE = 1,
         SPARSE = 2,
@ -21,7 +21,7 @@
 #ifndef LIBND4J_BYTEORDER_H
 #define LIBND4J_BYTEORDER_H

-namespace nd4j {
+namespace sd {
     enum ByteOrder {
         LE = 0,
         BE = 1,
@ -23,12 +23,12 @@

 #include <graph/generated/array_generated.h>
 #include "ByteOrder.h"
-#include <dll.h>
+#include <system/dll.h>

-namespace nd4j {
+namespace sd {
     class ND4J_EXPORT ByteOrderUtils {
     public:
-        static ByteOrder fromFlatByteOrder(nd4j::graph::ByteOrder order);
+        static ByteOrder fromFlatByteOrder(sd::graph::ByteOrder order);
     };
 }

@ -20,11 +20,11 @@
 #ifndef LIBND4J_CONSTANTDATABUFFER_H
 #define LIBND4J_CONSTANTDATABUFFER_H

-#include <dll.h>
+#include <system/dll.h>
-#include <pointercast.h>
+#include <system/pointercast.h>


-namespace nd4j {
+namespace sd {
     class ND4J_EXPORT ConstantDataBuffer {
     private:
         Nd4jPointer _primaryBuffer = nullptr;
@ -24,11 +24,11 @@
 #include <array/DataType.h>
 #include <unordered_map>
 #include <vector>
-#include <pointercast.h>
+#include <system/pointercast.h>
-#include <dll.h>
+#include <system/dll.h>
 #include <array/ConstantDataBuffer.h>

-namespace nd4j {
+namespace sd {
     class ND4J_EXPORT ConstantDescriptor {
     private:
         std::vector<Nd4jLong> _integerValues;

@ -59,5 +59,17 @@ namespace nd4j {
     };
 }

+#ifndef __JAVACPP_HACK__
+
+namespace std {
+    template<>
+    class ND4J_EXPORT hash<sd::ConstantDescriptor> {
+    public:
+        size_t operator()(const sd::ConstantDescriptor &k) const;
+    };
+}
+
+#endif
+
 #endif //DEV_TESTS_CONSTANTDESCRIPTOR_H
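The second hunk adds a `std::hash` specialization for `sd::ConstantDescriptor`, which is what allows the descriptor to key standard unordered containers (for constant-buffer caching). The same pattern on a toy key type, as a self-contained sketch; `Key` is hypothetical and stands in for ConstantDescriptor:

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Toy stand-in for ConstantDescriptor: a type used as an unordered_map
// key needs operator== plus a std::hash specialization.
struct Key {
    long long value;
    bool operator==(const Key& other) const { return value == other.value; }
};

namespace std {
    template<>
    struct hash<Key> {
        size_t operator()(const Key& k) const {
            return std::hash<long long>()(k.value);  // delegate to a built-in hash
        }
    };
}

int main() {
    std::unordered_map<Key, std::string> cache;
    cache[Key{42}] = "cached buffer";
    return cache.count(Key{42}) == 1 ? 0 : 1;
}
```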
@ -27,13 +27,13 @@
 #include <array/ConstantDataBuffer.h>
 #include <mutex>

-namespace nd4j {
+namespace sd {
     class ConstantHolder {
     private:
         int _deviceId = 0;
         std::mutex _mutex;

-        std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
+        std::map<sd::DataType, ConstantDataBuffer> _buffers;
     public:
         ConstantHolder(const ConstantHolder& other);
         ConstantHolder() = default;

@ -42,17 +42,17 @@ namespace nd4j {
         ConstantHolder& operator=(const ConstantHolder& other) = default;
         ConstantHolder& operator=(ConstantHolder&& other) = default;

-        bool hasBuffer(nd4j::DataType dataType);
+        bool hasBuffer(sd::DataType dataType);

         template <typename T>
         bool hasBuffer();

-        void addBuffer(ConstantDataBuffer &pointer, nd4j::DataType dataType);
+        void addBuffer(ConstantDataBuffer &pointer, sd::DataType dataType);

         template <typename T>
         void addBuffer(ConstantDataBuffer &pointer);

-        ConstantDataBuffer* getConstantDataBuffer(nd4j::DataType dataType);
+        ConstantDataBuffer* getConstantDataBuffer(sd::DataType dataType);

         template <typename T>
         ConstantDataBuffer* getConstantDataBuffer();
@ -23,14 +23,14 @@
 #define DEV_TESTS_DATABUFFER_H

 #include <cstring>
-#include <op_boilerplate.h>
+#include <system/op_boilerplate.h>
-#include <dll.h>
+#include <system/dll.h>
-#include <pointercast.h>
+#include <system/pointercast.h>
 #include <array/DataType.h>
 #include <memory/Workspace.h>
 #include <execution/LaunchContext.h>

-namespace nd4j {
+namespace sd {

     class ND4J_EXPORT DataBuffer {

@ -21,7 +21,7 @@
 #ifndef ND4J_DATATYPE_H
 #define ND4J_DATATYPE_H

-namespace nd4j {
+namespace sd {
     enum DataType {
         INHERIT = 0,
         BOOL = 1,
@ -21,17 +21,17 @@
 #ifndef LIBND4J_DATATYPECONVERSIONS_H
 #define LIBND4J_DATATYPECONVERSIONS_H

-#include <pointercast.h>
+#include <system/pointercast.h>
 #include <helpers/logger.h>
-#include <op_boilerplate.h>
+#include <system/op_boilerplate.h>
 #include <array/DataType.h>
 #include <types/float16.h>
 #include <helpers/BitwiseUtils.h>
 #include <loops/type_conversions.h>
-#include <dll.h>
+#include <system/dll.h>
 #include <execution/Threads.h>

-namespace nd4j {
+namespace sd {
     template <typename T>
     class ND4J_EXPORT DataTypeConversions {
     private:
@ -26,20 +26,20 @@
 #include <types/bfloat16.h>
 #include <array/DataType.h>
 #include <graph/generated/array_generated.h>
-#include <op_boilerplate.h>
+#include <system/op_boilerplate.h>
-#include <dll.h>
+#include <system/dll.h>
-#include <Environment.h>
+#include <system/Environment.h>
-#include <ArrayOptions.h>
+#include <array/ArrayOptions.h>
 //#include <templatemath.h>
-//#include <shape.h>
+//#include <helpers/shape.h>
 #include <helpers/logger.h>

-namespace nd4j {
+namespace sd {
     class ND4J_EXPORT DataTypeUtils {
     public:
         static int asInt(DataType type);
         static DataType fromInt(int dtype);
-        static DataType fromFlatDataType(nd4j::graph::DType dtype);
+        static DataType fromFlatDataType(sd::graph::DType dtype);
         FORCEINLINE static std::string asString(DataType dataType);

         template <typename T>

@ -70,21 +70,21 @@ namespace nd4j {
         FORCEINLINE static _CUDA_HD size_t sizeOf(DataType type);
         FORCEINLINE static _CUDA_HD size_t sizeOf(const Nd4jLong* shapeInfo);

-        FORCEINLINE static _CUDA_HD bool isR(nd4j::DataType dataType);
+        FORCEINLINE static _CUDA_HD bool isR(sd::DataType dataType);

-        FORCEINLINE static _CUDA_HD bool isZ(nd4j::DataType dataType);
+        FORCEINLINE static _CUDA_HD bool isZ(sd::DataType dataType);

-        FORCEINLINE static _CUDA_HD bool isB(nd4j::DataType dataType);
+        FORCEINLINE static _CUDA_HD bool isB(sd::DataType dataType);

-        FORCEINLINE static _CUDA_HD bool isU(nd4j::DataType dataType);
+        FORCEINLINE static _CUDA_HD bool isU(sd::DataType dataType);

-        FORCEINLINE static _CUDA_HD bool isS(nd4j::DataType dataType);
+        FORCEINLINE static _CUDA_HD bool isS(sd::DataType dataType);

-        FORCEINLINE static nd4j::DataType pickPairwiseResultType(nd4j::DataType typeX, nd4j::DataType typeY);
+        FORCEINLINE static sd::DataType pickPairwiseResultType(sd::DataType typeX, sd::DataType typeY);

-        FORCEINLINE static nd4j::DataType pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2);
+        FORCEINLINE static sd::DataType pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2);

-        FORCEINLINE static nd4j::DataType pickFloatingType(nd4j::DataType typeX);
+        FORCEINLINE static sd::DataType pickFloatingType(sd::DataType typeX);

         template <typename T1, typename T2>
         FORCEINLINE static std::vector<T2> convertVector(const std::vector<T1> &vector);

@ -106,38 +106,38 @@ namespace nd4j {
     ///// IMLEMENTATION OF INLINE METHODS /////
     //////////////////////////////////////////////////////////////////////////

-    FORCEINLINE nd4j::DataType DataTypeUtils::pickFloatingType(nd4j::DataType typeX) {
+    FORCEINLINE sd::DataType DataTypeUtils::pickFloatingType(sd::DataType typeX) {
         // if proposed dataType is already floating point - return it
         if (isR(typeX))
             return typeX;
         return Environment::getInstance()->defaultFloatDataType();
     }

-    FORCEINLINE bool DataTypeUtils::isR(nd4j::DataType dataType) {
+    FORCEINLINE bool DataTypeUtils::isR(sd::DataType dataType) {
-        return dataType == nd4j::DataType::FLOAT32 || dataType == nd4j::DataType::BFLOAT16 || dataType == nd4j::DataType::HALF || dataType == nd4j::DataType::DOUBLE;
+        return dataType == sd::DataType::FLOAT32 || dataType == sd::DataType::BFLOAT16 || dataType == sd::DataType::HALF || dataType == sd::DataType::DOUBLE;
     }

-    FORCEINLINE bool DataTypeUtils::isB(nd4j::DataType dataType) {
+    FORCEINLINE bool DataTypeUtils::isB(sd::DataType dataType) {
-        return dataType == nd4j::DataType::BOOL;
+        return dataType == sd::DataType::BOOL;
     }

-    FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) {
+    FORCEINLINE bool DataTypeUtils::isS(sd::DataType dataType) {
-        return dataType == nd4j::DataType::UTF8 || dataType == nd4j::DataType::UTF16 || dataType == nd4j::DataType::UTF32;
+        return dataType == sd::DataType::UTF8 || dataType == sd::DataType::UTF16 || dataType == sd::DataType::UTF32;
     }

-    FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) {
+    FORCEINLINE bool DataTypeUtils::isZ(sd::DataType dataType) {
         return !isR(dataType) && !isB(dataType) && !isS(dataType);
     }

-    FORCEINLINE bool DataTypeUtils::isU(nd4j::DataType dataType) {
+    FORCEINLINE bool DataTypeUtils::isU(sd::DataType dataType) {
-        return dataType == nd4j::DataType::UINT8 || dataType == nd4j::DataType::UINT16 || dataType == nd4j::DataType::UINT32 || dataType == nd4j::DataType::UINT64;
+        return dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64;
     }

-    FORCEINLINE nd4j::DataType DataTypeUtils::pickPairwiseResultType(nd4j::DataType typeX, nd4j::DataType typeY) {
+    FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType(sd::DataType typeX, sd::DataType typeY) {
         // if both dtypes are the same - just return it
         if (typeX == typeY)
             return typeX;
-        auto nd4j_max = [](nd4j::DataType typeX, nd4j::DataType typeY) {
+        auto nd4j_max = [](sd::DataType typeX, sd::DataType typeY) {
             return typeX > typeY?typeX:typeY;
         };
         auto rX = isR(typeX);

@ -154,7 +154,7 @@ namespace nd4j {
         // if both data types are float - return biggest one
         if (rX && rY) {
             // if we allow precision boost, then we pick bigger data type
-            if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
+            if (sd::Environment::getInstance()->precisionBoostAllowed()) {
                 return nd4j_max(typeX, typeY);
             } else {
                 // and we return first operand otherwise

@ -165,7 +165,7 @@ namespace nd4j {

         // if that's not real type, we apply same rules
         if (!rX && !rY) {
-            if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
+            if (sd::Environment::getInstance()->precisionBoostAllowed()) {
                 return nd4j_max(typeX, typeY);
             } else {
                 // and we return first operand otherwise

@ -177,7 +177,7 @@ namespace nd4j {
     }

     ///////////////////////////////////////////////////////////////////
-    FORCEINLINE nd4j::DataType DataTypeUtils::pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2) {
+    FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2) {
         return pickPairwiseResultType(ArrayOptions::dataType(shapeInfo1), ArrayOptions::dataType(shapeInfo2));
     }
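`pickPairwiseResultType` resolves the output type for a two-operand op: identical types pass through, two floating types take the wider one only when `precisionBoostAllowed()` is set (otherwise the first operand wins), and two integer types follow the same rule. A reduced standalone model of that rule; the mixed real/integer branch is elided from this hunk, so its handling below is an assumption (the floating operand wins), not a quote of the library:

```cpp
#include <cassert>

// Reduced model of the promotion rule visible in the hunks above. The enum
// ordering stands in for sd::DataType's numeric ordering, which is what the
// nd4j_max lambda compares.
enum class DType { INT8, INT32, INT64, HALF, FLOAT32, DOUBLE };

bool isReal(DType t) {
    return t == DType::HALF || t == DType::FLOAT32 || t == DType::DOUBLE;
}

// precisionBoost mirrors Environment::precisionBoostAllowed().
DType pickPairwise(DType x, DType y, bool precisionBoost) {
    if (x == y)
        return x;                                // same dtype: nothing to decide
    if (isReal(x) == isReal(y))                  // both real, or both non-real
        return precisionBoost ? (x > y ? x : y)  // boost: wider type wins
                              : x;               // otherwise: first operand
    // mixed real/integer handling lives in lines elided from this hunk;
    // assumption for the sketch: the real operand wins
    return isReal(x) ? x : y;
}

int main() {
    assert(pickPairwise(DType::FLOAT32, DType::DOUBLE, true)  == DType::DOUBLE);
    assert(pickPairwise(DType::FLOAT32, DType::DOUBLE, false) == DType::FLOAT32);
    assert(pickPairwise(DType::INT32, DType::INT32, false)    == DType::INT32);
}
```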
@ -420,31 +420,31 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
         return result;
 }

-FORCEINLINE _CUDA_HD size_t DataTypeUtils::sizeOfElement(nd4j::DataType type) {
+FORCEINLINE _CUDA_HD size_t DataTypeUtils::sizeOfElement(sd::DataType type) {
     switch (type) {
-        case nd4j::DataType::UINT8:
+        case sd::DataType::UINT8:
-        case nd4j::DataType::INT8:
+        case sd::DataType::INT8:
-        case nd4j::DataType::FLOAT8:
+        case sd::DataType::FLOAT8:
-        case nd4j::DataType::QINT8:
+        case sd::DataType::QINT8:
-        case nd4j::DataType::BOOL: return (size_t) 1;
+        case sd::DataType::BOOL: return (size_t) 1;

-        case nd4j::DataType::BFLOAT16:
+        case sd::DataType::BFLOAT16:
-        case nd4j::DataType::HALF:
+        case sd::DataType::HALF:
-        case nd4j::DataType::INT16:
+        case sd::DataType::INT16:
-        case nd4j::DataType::QINT16:
+        case sd::DataType::QINT16:
-        case nd4j::DataType::UINT16: return (size_t) 2;
+        case sd::DataType::UINT16: return (size_t) 2;

-        case nd4j::DataType::UTF8:
+        case sd::DataType::UTF8:
-        case nd4j::DataType::UTF16:
+        case sd::DataType::UTF16:
-        case nd4j::DataType::UTF32:
+        case sd::DataType::UTF32:
-        case nd4j::DataType::INT32:
+        case sd::DataType::INT32:
-        case nd4j::DataType::UINT32:
+        case sd::DataType::UINT32:
-        case nd4j::DataType::HALF2:
+        case sd::DataType::HALF2:
-        case nd4j::DataType::FLOAT32: return (size_t) 4;
+        case sd::DataType::FLOAT32: return (size_t) 4;

-        case nd4j::DataType::UINT64:
+        case sd::DataType::UINT64:
-        case nd4j::DataType::INT64:
+        case sd::DataType::INT64:
-        case nd4j::DataType::DOUBLE: return (size_t) 8;
+        case sd::DataType::DOUBLE: return (size_t) 8;

         default: {
             nd4j_printf("Unknown DataType used: [%i]\n", asInt(type));
@ -456,41 +456,41 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
 }

 template <typename T>
-FORCEINLINE _CUDA_HD nd4j::DataType nd4j::DataTypeUtils::fromT() {
+FORCEINLINE _CUDA_HD sd::DataType sd::DataTypeUtils::fromT() {
     if (std::is_same<T, bool>::value) {
-        return nd4j::DataType::BOOL;
+        return sd::DataType::BOOL;
     } else if (std::is_same<T, std::string>::value) {
-        return nd4j::DataType::UTF8;
+        return sd::DataType::UTF8;
     } else if (std::is_same<T, std::u16string>::value) {
-        return nd4j::DataType::UTF16;
+        return sd::DataType::UTF16;
     } else if (std::is_same<T, std::u32string>::value) {
-        return nd4j::DataType::UTF32;
+        return sd::DataType::UTF32;
     } else if (std::is_same<T, float>::value) {
-        return nd4j::DataType::FLOAT32;
+        return sd::DataType::FLOAT32;
     } else if (std::is_same<T, float16>::value) {
-        return nd4j::DataType::HALF;
+        return sd::DataType::HALF;
     } else if (std::is_same<T, bfloat16>::value) {
-        return nd4j::DataType::BFLOAT16;
+        return sd::DataType::BFLOAT16;
     } else if (std::is_same<T, double>::value) {
-        return nd4j::DataType::DOUBLE;
+        return sd::DataType::DOUBLE;
     } else if (std::is_same<T, int8_t >::value) {
-        return nd4j::DataType::INT8;
+        return sd::DataType::INT8;
     } else if (std::is_same<T, int16_t >::value) {
-        return nd4j::DataType::INT16;
+        return sd::DataType::INT16;
     } else if (std::is_same<T, int>::value) {
-        return nd4j::DataType::INT32;
+        return sd::DataType::INT32;
     } else if (std::is_same<T, Nd4jLong>::value) {
-        return nd4j::DataType::INT64;
+        return sd::DataType::INT64;
     } else if (std::is_same<T, uint8_t>::value) {
-        return nd4j::DataType::UINT8;
+        return sd::DataType::UINT8;
     } else if (std::is_same<T, uint16_t>::value) {
-        return nd4j::DataType::UINT16;
+        return sd::DataType::UINT16;
     } else if (std::is_same<T, uint32_t>::value) {
-        return nd4j::DataType::UINT32;
+        return sd::DataType::UINT32;
     } else if (std::is_same<T, Nd4jULong>::value) {
-        return nd4j::DataType::UINT64;
+        return sd::DataType::UINT64;
     } else {
-        return nd4j::DataType::INHERIT;
+        return sd::DataType::INHERIT;
     }
 }
 }
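`fromT()` maps the C++ template parameter to the matching enum value through a chain of `std::is_same` tests; every branch condition is a compile-time constant, so each instantiation folds down to a single constant return. A toy standalone version of the same technique (the enum and type set here are illustrative, not libnd4j's full list):

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

enum class DType { BOOL, FLOAT32, DOUBLE, INT32, INT64, UNKNOWN };

// Same technique as DataTypeUtils::fromT(): resolve the enum from the
// template parameter with an is_same chain. The conditions are constants,
// so the optimizer collapses the whole function to one constant.
template <typename T>
DType fromT() {
    if (std::is_same<T, bool>::value)         return DType::BOOL;
    else if (std::is_same<T, float>::value)   return DType::FLOAT32;
    else if (std::is_same<T, double>::value)  return DType::DOUBLE;
    else if (std::is_same<T, int32_t>::value) return DType::INT32;
    else if (std::is_same<T, int64_t>::value) return DType::INT64;
    else                                      return DType::UNKNOWN;
}

int main() {
    assert(fromT<float>() == DType::FLOAT32);
    assert(fromT<int64_t>() == DType::INT64);
}
```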
@ -21,14 +21,14 @@
 #ifndef DEV_TESTS_EXTRAARGUMENTS_H
 #define DEV_TESTS_EXTRAARGUMENTS_H

-#include <dll.h>
+#include <system/dll.h>
 #include <initializer_list>
 #include <vector>
 #include <array/DataType.h>
-#include <pointercast.h>
+#include <system/pointercast.h>
 #include <stdlib.h>

-namespace nd4j {
+namespace sd {
     class ND4J_EXPORT ExtraArguments {
     private:
         std::vector<double> _fpArgs;

@ -54,7 +54,7 @@ namespace nd4j {
         template <typename T>
         void* argumentsAsT(Nd4jLong offset = 0);

-        void* argumentsAsT(nd4j::DataType dataType, Nd4jLong offset = 0);
+        void* argumentsAsT(sd::DataType dataType, Nd4jLong offset = 0);

         size_t length();
     };
@ -18,7 +18,7 @@
 // @author raver119@gmail.com
 //

-#include <dll.h>
+#include <system/dll.h>
 #include <array/DataBuffer.h>
 #include <array/DataType.h>
 #include <memory>

@ -26,7 +26,7 @@
 #ifndef LIBND4J_INTEROPDATABUFFER_H
 #define LIBND4J_INTEROPDATABUFFER_H

-namespace nd4j {
+namespace sd {
     /**
     * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages
     */

@ -37,7 +37,7 @@ namespace nd4j {
     public:
         InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset);
         InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer);
-        InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth);
+        InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth);
         ~InteropDataBuffer() = default;

 #ifndef __JAVACPP_HACK__
@ -17,11 +17,11 @@
 #ifndef NDARRAY_H
 #define NDARRAY_H

-#include <dll.h>
+#include <system/dll.h>
 #include <initializer_list>
 #include <functional>
-#include <shape.h>
+#include <helpers/shape.h>
-#include "NativeOpExecutioner.h"
+#include "legacy/NativeOpExecutioner.h"
 #include <indexing/NDIndex.h>
 #include <indexing/IndicesList.h>
 #include <graph/Intervals.h>

@ -32,13 +32,13 @@
 #include <array/ArrayType.h>
 #include <array/ResultSet.h>
 #include <helpers/ShapeBuilders.h>
-#include <op_enums.h>
+#include <system/op_enums.h>
 #include <ops/BroadcastOpsTuple.h>
 #include <ops/BroadcastBoolOpsTuple.h>
 #include <ops/BroadcastIntOpsTuple.h>
 #include <array/ExtraArguments.h>
-#include <Status.h>
+#include <graph/Status.h>
-#include <ShapeDescriptor.h>
+#include <array/ShapeDescriptor.h>
 #include <helpers/ConstantShapeHelper.h>
 #include <array/DataBuffer.h>
 #include <execution/AffinityManager.h>

@ -47,7 +47,7 @@
 #include <memory/MemoryCounter.h>


-namespace nd4j {
+namespace sd {

     template <typename T, typename = typename std::enable_if<DataTypeUtils::scalarTypesForNDarray<T>::value>::type>
     ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar);

@ -116,7 +116,7 @@ namespace nd4j {
         void templatedSet(void *buffer, const Nd4jLong xOffset, const void *value);

         template <typename T>
-        void templatedSet(void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value);
+        void templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value);

         template <typename T>
         void templatedAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const;

@ -161,7 +161,7 @@ namespace nd4j {
         /**
         *  pointer on device launch context (with all data needed there).
         */
-        nd4j::LaunchContext * _context = nd4j::LaunchContext::defaultContext();
+        sd::LaunchContext * _context = sd::LaunchContext::defaultContext();

         // indicates if array's buffer is within workspace
         bool _isAttached = false;

@ -174,7 +174,7 @@ namespace nd4j {
         /**
         *  type of array elements
         */
-        nd4j::DataType _dataType = FLOAT32;
+        sd::DataType _dataType = FLOAT32;

         /**
         *  deviceID where this NDArray belongs to
@ -191,72 +191,72 @@ namespace nd4j {
         *  do not allocate memory, memory for array is passed from outside
         */
 #ifndef __JAVACPP_HACK__
-        NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const Nd4jLong offset = 0);
+        NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const Nd4jLong offset = 0);

-        NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This contructors create scalar array containing string utf8
         *
         */
-        NDArray(const char* str, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+        NDArray(const char* str, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext())
             : NDArray(std::string(str), dtype, context) {
         }
-        NDArray(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This contructors create scalar array containing string utf16
         *
         */
-        NDArray(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+        NDArray(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext())
             : NDArray(std::u16string(u16string), dtype, context) {
         }

-        NDArray(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This contructors create scalar array containing string utf32
         *
         */
-        NDArray(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+        NDArray(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext())
             : NDArray(std::u32string(u32string), dtype, context) {
         }

-        NDArray(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This contructors create array from vector of utf8 strings
         *
         */
-        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
-        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This contructors create array from vector of utf16 strings
         *
         */
-        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
-        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This contructors create array from vector of utf32 strings
         *
         */
-        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
-        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

 #endif

         /**
         *  do not allocate memory, memory for array is passed from outside
         */
-        NDArray(void *buffer, Nd4jLong* shapeInfo, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
+        NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false);

         /**
         *  do not allocate memory, memory for array is passed from outside
         *  we suppose the content of both (device and host) buffers is identical
         */
-        NDArray(void *buffer, void *bufferD, Nd4jLong* shapeInfo, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false, const bool isBuffDAlloc = false);
+        NDArray(void *buffer, void *bufferD, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false, const bool isBuffDAlloc = false);

         /**
         *  copy constructor

@ -271,34 +271,34 @@ namespace nd4j {
         /**
         *  constructor, create array stored at given workspace
         */
-        NDArray(nd4j::LaunchContext * context);
+        NDArray(sd::LaunchContext * context);


         /**
         *  constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
         */
-        NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         *  constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
         *  set dtype as array type
         */
-        NDArray(Nd4jLong* shapeInfo, const nd4j::DataType dtype, const bool copyStrides = false, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         *  this constructor creates new array using shape information contained in vector argument
         */
-        NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         * This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype
         */
-        NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext());

         /**
         *  this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
         */
-        NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
+        NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false);

         /**
         * This method returns new array with the same shape & data type
@ -310,19 +310,19 @@ namespace nd4j {
         * This method returns new uninitialized array with the same shape & data type
         * @return
         */
-        NDArray ulike();
+        NDArray ulike() const;


        /**
        *  this constructor creates new NDArray with shape matching "other" array,
        *  doesn't copy "other" elements into new array !!!
        */
-        explicit NDArray(const NDArray* other, const bool copyStrides = false, nd4j::LaunchContext* context = nd4j::LaunchContext ::defaultContext());
+        explicit NDArray(const NDArray* other, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext());

        /**
        *  this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar
        */
-        NDArray(nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isScalar = true);
+        NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isScalar = true);

        /**
        * This method blocks until asynchronous operation finishes

@ -392,7 +392,7 @@ namespace nd4j {
         void operator delete(void* p);


-        void setContext(nd4j::LaunchContext * context);
+        void setContext(sd::LaunchContext * context);

         /**
         *  create a new array by replicating current array by repeats times along given dimension

@ -438,7 +438,7 @@ namespace nd4j {
         /**
         *  returns _context
         */
-        nd4j::LaunchContext * getContext() const {
+        sd::LaunchContext * getContext() const {
             return _context;
         }

@ -626,17 +626,17 @@ namespace nd4j {
         *  keepDims - if true then put unities in place of reduced dimensions
         */

-        NDArray reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
-        NDArray reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;

-        NDArray reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
-        NDArray reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;

-        NDArray reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
-        NDArray reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;

-        NDArray reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
-        NDArray reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+        NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;

         /**
         *  method reduces array by excluding its shapes along dimensions present in given dimensions vector
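The `keepDims` flag documented in these declarations controls only the result shape: reduced dimensions are either dropped or kept as unities. A small standalone sketch of that shape arithmetic (not the NDArray implementation, just the documented rule):

```cpp
#include <cassert>
#include <set>
#include <vector>

// Shape arithmetic behind the keepDims flag documented above: reducing a
// [2,3,4] array over dimension 1 yields [2,4], or [2,1,4] with keepDims,
// since reduced dimensions are replaced by unities instead of dropped.
std::vector<long long> reducedShape(const std::vector<long long>& shape,
                                    const std::set<int>& dims,
                                    bool keepDims) {
    std::vector<long long> out;
    for (int i = 0; i < (int) shape.size(); i++) {
        if (dims.count(i) == 0)
            out.push_back(shape[i]);   // dimension survives untouched
        else if (keepDims)
            out.push_back(1);          // reduced dimension kept as unity
    }
    return out;
}

int main() {
    assert(reducedShape({2, 3, 4}, {1}, false) == (std::vector<long long>{2, 4}));
    assert(reducedShape({2, 3, 4}, {1}, true)  == (std::vector<long long>{2, 1, 4}));
}
```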
@ -645,37 +645,37 @@ namespace nd4j {
         *  keepDims - if true then put unities in place of reduced dimensions
         *  extras - extra parameters
         */
-        void reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
+        void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
-        void reduceAlongDimension(nd4j::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
+        void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
-        void reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
+        void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
-        void reduceAlongDimension(nd4j::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
+        void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;

         /**
         *  return variance of array elements set
         *  biasCorrected -  if true bias correction will be applied
         */
-        NDArray varianceNumber(nd4j::variance::Ops op, bool biasCorrected = true);
+        NDArray varianceNumber(sd::variance::Ops op, bool biasCorrected = true);

         /**
         *  apply scalar operation to array
         *  extraParams - extra parameters for operation
         *  returns scalar array
         */
-        NDArray reduceNumber(nd4j::reduce::FloatOps ops, void *extraParams = nullptr) const;
+        NDArray reduceNumber(sd::reduce::FloatOps ops, void *extraParams = nullptr) const;
-        NDArray reduceNumber(nd4j::reduce::SameOps ops, void *extraParams = nullptr) const;
+        NDArray reduceNumber(sd::reduce::SameOps ops, void *extraParams = nullptr) const;
-        NDArray reduceNumber(nd4j::reduce::BoolOps ops, void *extraParams = nullptr) const;
+        NDArray reduceNumber(sd::reduce::BoolOps ops, void *extraParams = nullptr) const;
-        NDArray reduceNumber(nd4j::reduce::LongOps ops, void *extraParams = nullptr) const;
+        NDArray reduceNumber(sd::reduce::LongOps ops, void *extraParams = nullptr) const;

-        void reduceNumber(nd4j::reduce::FloatOps ops, NDArray& target, void *extraParams = nullptr) const;
+        void reduceNumber(sd::reduce::FloatOps ops, NDArray& target, void *extraParams = nullptr) const;
-        void reduceNumber(nd4j::reduce::SameOps ops, NDArray& target, void *extraParams = nullptr) const;
+        void reduceNumber(sd::reduce::SameOps ops, NDArray& target, void *extraParams = nullptr) const;
-        void reduceNumber(nd4j::reduce::BoolOps ops, NDArray& target, void *extraParams = nullptr) const;
+        void reduceNumber(sd::reduce::BoolOps ops, NDArray& target, void *extraParams = nullptr) const;
-        void reduceNumber(nd4j::reduce::LongOps ops, NDArray& target, void *extraParams = nullptr) const;
+        void reduceNumber(sd::reduce::LongOps ops, NDArray& target, void *extraParams = nullptr) const;

         /**
         *  returns element index which corresponds to some condition imposed by operation
         *  extraParams - extra parameters for operation
         */
-        NDArray indexReduceNumber(nd4j::indexreduce::Ops op, ExtraArguments *extraParams = nullptr);
+        NDArray indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams = nullptr);

         /**
         *  returns index of max element in a given array (optionally: along given dimension(s))
@ -687,31 +687,31 @@ namespace nd4j {
         void makeBothActual() const { syncToDevice(); syncToHost(); }


-        void applyTransform(nd4j::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
+        void applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
-        void applyTransform(nd4j::transform::SameOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
+        void applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
-        void applyTransform(nd4j::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
+        void applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
-        void applyTransform(nd4j::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
+        void applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
-        void applyTransform(nd4j::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams = nullptr);
+        void applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams = nullptr);

         /**
         *  apply OpName transformation to this array and store result in new array to be returned
         *  extraParams - extra parameters for operation
         */
-        NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) const &;
+        NDArray transform(sd::transform::FloatOps op, void *extraParams = nullptr) const &;
-        NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) const &;
+        NDArray transform(sd::transform::SameOps op, void *extraParams = nullptr) const &;
-        NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) const &;
+        NDArray transform(sd::transform::BoolOps op, void *extraParams = nullptr) const &;
-        NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) const &;
+        NDArray transform(sd::transform::StrictOps op, void *extraParams = nullptr) const &;
-        NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) &&;
+        NDArray transform(sd::transform::FloatOps op, void *extraParams = nullptr) &&;
-        NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) &&;
+        NDArray transform(sd::transform::SameOps op, void *extraParams = nullptr) &&;
-        NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) &&;
+        NDArray transform(sd::transform::BoolOps op, void *extraParams = nullptr) &&;
-        NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) &&;
+        NDArray transform(sd::transform::StrictOps op, void *extraParams = nullptr) &&;

/**
|
/**
|
||||||
* apply pairwise OpName transformation based on "this" and "other" arras elements, store result in this array
|
* apply pairwise OpName transformation based on "this" and "other" arras elements, store result in this array
|
||||||
* other - second array necessary for pairwise operation
|
* other - second array necessary for pairwise operation
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
void applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams = nullptr);
|
void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams = nullptr);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply pairwise OpName transformation based on "this" and "other" arras elements, store result in target array
|
* apply pairwise OpName transformation based on "this" and "other" arras elements, store result in target array
|
||||||
|
@ -719,11 +719,11 @@ namespace nd4j {
|
||||||
* target - where to store result
|
* target - where to store result
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
void applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
void applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray& other, NDArray&target, ExtraArguments *extraParams = nullptr) const;
|
void applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray&target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this)
|
* apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this)
|
||||||
|
@ -732,23 +732,23 @@ namespace nd4j {
|
||||||
* target - where to store result
|
* target - where to store result
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
void applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray& tad, NDArray& target, ExtraArguments* extraArgs = nullptr);
|
void applyBroadcast(sd::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray& tad, NDArray& target, ExtraArguments* extraArgs = nullptr);
|
||||||
|
|
||||||
void applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int> &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr);
|
void applyBroadcast(sd::broadcast::Ops op, const std::vector<int> &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr);
|
||||||
|
|
||||||
void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int> &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr);
|
void applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int> &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr);
|
||||||
|
|
||||||
void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int> &dimensions, const NDArray& tad, NDArray &target, ExtraArguments *extraArgs = nullptr);
|
void applyBroadcast(sd::broadcast::IntOps op, const std::vector<int> &dimensions, const NDArray& tad, NDArray &target, ExtraArguments *extraArgs = nullptr);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting
|
* apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting
|
||||||
* other - input array
|
* other - input array
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const &;
|
NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const &;
|
||||||
NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) const &;
|
NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) const &;
|
||||||
NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) &&;
|
NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) &&;
|
||||||
NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) &&;
|
NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) &&;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting
|
* apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting
|
||||||
|
@ -757,11 +757,11 @@ namespace nd4j {
|
||||||
* checkTargetShape - if true check whether target shape is suitable for broadcasting
|
* checkTargetShape - if true check whether target shape is suitable for broadcasting
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
void applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
|
void applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
|
||||||
|
|
||||||
void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
|
void applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
|
||||||
|
|
||||||
void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
|
void applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -771,13 +771,13 @@ namespace nd4j {
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr);
|
void applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
void applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
void applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply a scalar operation to an array
|
* apply a scalar operation to an array
|
||||||
|
@ -785,11 +785,11 @@ namespace nd4j {
|
||||||
* target - where to store result
|
* target - where to store result
|
||||||
* extraParams - extra parameters for operation
|
* extraParams - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
void applyScalarArr(nd4j::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr);
|
void applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr);
|
||||||
|
|
||||||
void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
void applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
void applyScalarArr(nd4j::scalar::IntOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
void applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
|
#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
|
||||||
template <typename Lambda>
|
template <typename Lambda>
|
||||||
|
@ -840,7 +840,7 @@ namespace nd4j {
|
||||||
* dimensions - vector of dimensions to reduce along
|
* dimensions - vector of dimensions to reduce along
|
||||||
* extraArgs - extra parameters for operation
|
* extraArgs - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
NDArray applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector<int>& dimensions, const ExtraArguments *extraParams = nullptr) const;
|
NDArray applyIndexReduce(sd::indexreduce::Ops op, const std::vector<int>& dimensions, const ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* reduces dimensions in array relying on index operation OpName
|
* reduces dimensions in array relying on index operation OpName
|
||||||
|
@ -848,14 +848,14 @@ namespace nd4j {
|
||||||
* dimensions - vector of dimensions to reduce along
|
* dimensions - vector of dimensions to reduce along
|
||||||
* extraArgs - extra parameters for operation
|
* extraArgs - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
void applyIndexReduce(nd4j::indexreduce::Ops op, NDArray& target, const std::vector<int>& dimensions, const ExtraArguments *extraParams = nullptr) const;
|
void applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector<int>& dimensions, const ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply reduce3 operation OpName to this and other array, return result in new output array
|
* apply reduce3 operation OpName to this and other array, return result in new output array
|
||||||
* other - input array
|
* other - input array
|
||||||
* extraArgs - extra parameters for operation
|
* extraArgs - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
NDArray applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams = nullptr) const;
|
NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply reduce3 operation OpName to this and other array, return result in new output array
|
* apply reduce3 operation OpName to this and other array, return result in new output array
|
||||||
|
@ -863,7 +863,7 @@ namespace nd4j {
|
||||||
* dimensions - vector of dimensions to reduce along (tads not axis)
|
* dimensions - vector of dimensions to reduce along (tads not axis)
|
||||||
* extraArgs - extra parameters for operation
|
* extraArgs - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
NDArray applyAllReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector<int>& dimensions, const ExtraArguments* extraParams = nullptr) const;
|
NDArray applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector<int>& dimensions, const ExtraArguments* extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* apply reduce3 (exec) operation OpName to this and other array, return result in new output array
|
* apply reduce3 (exec) operation OpName to this and other array, return result in new output array
|
||||||
|
@ -871,18 +871,18 @@ namespace nd4j {
|
||||||
* dimensions - vector of dimensions to reduce along (same as reduceAlongDimension)
|
* dimensions - vector of dimensions to reduce along (same as reduceAlongDimension)
|
||||||
* extraArgs - extra parameters for operation
|
* extraArgs - extra parameters for operation
|
||||||
*/
|
*/
|
||||||
NDArray applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector<int>& dimensions, const ExtraArguments* extraParams = nullptr) const;
|
NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector<int>& dimensions, const ExtraArguments* extraParams = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns variance along given dimensions
|
* returns variance along given dimensions
|
||||||
* biasCorrected - if true bias correction will be applied
|
* biasCorrected - if true bias correction will be applied
|
||||||
* dimensions - vector of dimensions to calculate variance along
|
* dimensions - vector of dimensions to calculate variance along
|
||||||
*/
|
*/
|
||||||
NDArray varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector<int>& dimensions) const;
|
NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector<int>& dimensions) const;
|
||||||
NDArray varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list<int>& dimensions) const;
|
NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list<int>& dimensions) const;
|
||||||
|
|
||||||
void varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector<int>& dimensions) const;
|
void varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector<int>& dimensions) const;
|
||||||
void varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::initializer_list<int>& dimensions) const;
|
void varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::initializer_list<int>& dimensions) const;
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -903,14 +903,6 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
void transposei();
|
void transposei();
|
||||||
|
|
||||||
/**
|
|
||||||
* return array pointing on certain range of this array
|
|
||||||
* index - the number of array to be returned among set of possible arrays
|
|
||||||
* dimensions - array of dimensions to point on
|
|
||||||
*/
|
|
||||||
NDArray tensorAlongDimension(Nd4jLong index, const std::initializer_list<int>& dimensions) const;
|
|
||||||
NDArray tensorAlongDimension(Nd4jLong index, const std::vector<int>& dimensions) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns the number of arrays pointing on specified dimension(s)
|
* returns the number of arrays pointing on specified dimension(s)
|
||||||
* dimensions - array of dimensions to point on
|
* dimensions - array of dimensions to point on
|
||||||
|
@ -1224,7 +1216,7 @@ namespace nd4j {
|
||||||
* set _shapeInfo
|
* set _shapeInfo
|
||||||
*/
|
*/
|
||||||
void setShapeInfo(const Nd4jLong *shapeInfo);
|
void setShapeInfo(const Nd4jLong *shapeInfo);
|
||||||
void setShapeInfo(const Nd4jLong *shapeInfo, const nd4j::DataType dtype);
|
void setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype);
|
||||||
void setShapeInfo(const ShapeDescriptor& descriptor);
|
void setShapeInfo(const ShapeDescriptor& descriptor);
|
||||||
void setShapeInfo(const ConstantDataBuffer& shapeBuffer);
|
void setShapeInfo(const ConstantDataBuffer& shapeBuffer);
|
||||||
|
|
||||||
|
@ -1271,7 +1263,7 @@ namespace nd4j {
|
||||||
* set _shapeInfo
|
* set _shapeInfo
|
||||||
*/
|
*/
|
||||||
FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo);
|
FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo);
|
||||||
FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype);
|
FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns the value of "dim" dimension
|
* returns the value of "dim" dimension
|
||||||
|
@ -1537,13 +1529,13 @@ void NDArray::setShapeInfo(Nd4jLong *shapeInfo) {
|
||||||
_length = shape::length(_shapeInfo);
|
_length = shape::length(_shapeInfo);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
_dataType = nd4j::DataType::INHERIT;
|
_dataType = sd::DataType::INHERIT;
|
||||||
_length = 0;
|
_length = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype) {
|
void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype) {
|
||||||
auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo);
|
auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo);
|
||||||
_shapeInfo = buffer.primaryAsT<Nd4jLong>();
|
_shapeInfo = buffer.primaryAsT<Nd4jLong>();
|
||||||
_shapeInfoD = buffer.specialAsT<Nd4jLong>();
|
_shapeInfoD = buffer.specialAsT<Nd4jLong>();
|
||||||
|
@ -1556,7 +1548,7 @@ void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype) {
|
||||||
_length = shape::length(_shapeInfo);
|
_length = shape::length(_shapeInfo);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
_dataType = nd4j::DataType::INHERIT;
|
_dataType = sd::DataType::INHERIT;
|
||||||
_length = 0;
|
_length = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1702,7 +1694,7 @@ bool NDArray::isSameShape(const std::vector<Nd4jLong>& shape) const{
|
||||||
if (this->rankOf() != (int) shape.size())
|
if (this->rankOf() != (int) shape.size())
|
||||||
return false;
|
return false;
|
||||||
for (int e = 0; e < this->rankOf(); e++) {
|
for (int e = 0; e < this->rankOf(); e++) {
|
||||||
if (this->shapeOf()[e] != shape.at(e) && shape.at(e) != -1)
|
if (this->shapeOf()[e] != shape[e] && shape[e] != -1)
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -1981,7 +1973,7 @@ Nd4jLong* NDArray::getSpecialShapeInfo() const{
|
||||||
|
|
||||||
#if defined(__CUDACC__) //&& defined(BUILD_TESTS)
|
#if defined(__CUDACC__) //&& defined(BUILD_TESTS)
|
||||||
// for CUDA we need stil stuff inline
|
// for CUDA we need stil stuff inline
|
||||||
#include "cuda/NDArrayLambda.hpp"
|
#include <array/NDArrayLambda.hXX>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}
|
}
|
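Note: the dominant change across the hunks above is the rename of the top-level namespace from nd4j to sd, so every qualified op-enum and type reference changes spelling while the signatures themselves stay the same. A minimal sketch of what that means for calling code (hypothetical caller, not part of this PR; reduce::Sum assumed available under the new namespace as it was under the old):

#include <array/NDArray.h>

// Before this PR a caller wrote nd4j::NDArray / nd4j::reduce::Sum;
// after it, the same code spells the sd namespace instead.
void sumToScalar(sd::NDArray &arr) {
    // reduceNumber collapses the whole array into a scalar NDArray;
    // the op enum now lives under sd::reduce rather than nd4j::reduce.
    sd::NDArray total = arr.reduceNumber(sd::reduce::Sum);
}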
(File diff suppressed because it is too large)
@ -0,0 +1,191 @@
+ /*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ //
+ // Created by raver119 on 2018-09-16.
+ // @author Oleg Semeniv <oleg.semeniv@gmail.com>
+ //
+
+ #ifndef DEV_TESTS_NDARRAYFACTORY_H
+ #define DEV_TESTS_NDARRAYFACTORY_H
+
+ #include <vector>
+ #include <initializer_list>
+ #include <array/NDArray.h>
+ //#include <memory/Workspace.h>
+ #include <execution/LaunchContext.h>
+ #include <string>
+
+
+ namespace sd {
+ class ND4J_EXPORT NDArrayFactory {
+ private:
+ template <typename T>
+ static void memcpyFromVector(void *ptr, const std::vector<T> &vector);
+ public:
+ template <typename T>
+ static NDArray* empty_(sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ static NDArray* empty_(sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray empty(sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ static NDArray empty(sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray* valueOf(const std::initializer_list<Nd4jLong>& shape, T value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray* valueOf(const std::vector<Nd4jLong>& shape, T value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ static NDArray* valueOf(const std::vector<Nd4jLong>& shape, const NDArray& value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray* linspace(T from, T to, Nd4jLong numElements);
+
+
+ template <typename T>
+ static NDArray* create_(const T value, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* create_(sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray create(const T value, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray create(sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ template <typename T>
+ static NDArray create(DataType type, const T scalar, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+
+ template <typename T>
+ static NDArray* vector(Nd4jLong length, T startingValue = (T) 0, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray* create_(char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ static NDArray* create_( char order, const std::vector<Nd4jLong> &shape, sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray* create_(char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray create(char order, const std::vector<Nd4jLong> &shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray create(char order, const std::vector<Nd4jLong> &shape, sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray create(const std::vector<T> &values, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ #ifndef __JAVACPP_HACK__
+ // this method only available out of javacpp
+ /**
+ * This constructor creates vector of T
+ *
+ * @param values
+ */
+
+ template <typename T>
+ static NDArray create(char order, const std::initializer_list<Nd4jLong>& shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray create(T* buffer, char order, const std::initializer_list<Nd4jLong>& shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ template <typename T>
+ static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ /**
+ * This method creates NDArray from .npy file
+ * @param fileName
+ * @return
+ */
+ static NDArray fromNpyFile(const char *fileName);
+
+ /**
+ * This factory create array from utf8 string
+ * @return NDArray default dataType UTF8
+ */
+ static NDArray string(const char *string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* string_(const char *string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* string_(const std::string &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray string(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+
+ /**
+ * This factory create array from utf16 string
+ * @return NDArray default dataType UTF16
+ */
+ static NDArray string(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+
+ /**
+ * This factory create array from utf32 string
+ * @return NDArray default dataType UTF32
+ */
+ static NDArray string(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+
+ /**
+ * This factory create array from vector of utf8 strings
+ * @return NDArray default dataType UTF8
+ */
+ static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ /**
+ * This factory create array from vector of utf16 strings
+ * @return NDArray default dataType UTF16
+ */
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+
+ /**
+ * This factory create array from vector of utf32 strings
+ * @return NDArray default dataType UTF32
+ */
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+ static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+
+
+ static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, sd::LaunchContext * context = sd::LaunchContext ::defaultContext());
+
+ #endif
+ };
+ }
+
+ #endif //DEV_TESTS_NDARRAYFACTORY_H
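A hedged usage sketch for the factory declared above (values illustrative; only overloads visible in this new header are used):

#include <array/NDArrayFactory.h>

void factoryExample() {
    // create(char order, shape, data): 2x3 float array in row-major ('c') order
    auto arr = sd::NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});

    // create(const T value): scalar array holding a single double
    auto scalar = sd::NDArrayFactory::create<double>(3.14);

    // string(...): UTF-8 string array, per the factory comment above
    auto s = sd::NDArrayFactory::string("hello");
}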
@ -17,17 +17,17 @@
#ifndef CUDA_LAMBDA_HELPER
#define CUDA_LAMBDA_HELPER

- #include <pointercast.h>
+ #include <system/pointercast.h>
- #include <op_boilerplate.h>
+ #include <system/op_boilerplate.h>
#include <helpers/shape.h>
#include <cuda.h>
#include <cuda_runtime.h>

- static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+ static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo);
}

- static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) {
+ static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}

@ -94,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);

- auto zLength = __length(zShapeInfo);
+ auto zLength = length(zShapeInfo);

auto tid = threadIdx.x + blockIdx.x * blockDim.x;

@ -103,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
z[e * zEws] = lambda(x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = __getIndexOffset(e, xShapeInfo);
+ auto xOffset = getIndexOffset(e, xShapeInfo);
- auto zOffset = __getIndexOffset(e, zShapeInfo);
+ auto zOffset = getIndexOffset(e, zShapeInfo);

z[zOffset] = lambda(x[xOffset]);
}

@ -123,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);

- auto zLength = __length(zShapeInfo);
+ auto zLength = length(zShapeInfo);

auto tid = threadIdx.x + blockIdx.x * blockDim.x;

@ -132,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
z[e * zEws] = lambda(e, x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = __getIndexOffset(e, xShapeInfo);
+ auto xOffset = getIndexOffset(e, xShapeInfo);
- auto zOffset = __getIndexOffset(e, zShapeInfo);
+ auto zOffset = getIndexOffset(e, zShapeInfo);

z[zOffset] = lambda(e, x[xOffset]);
}

@ -155,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);

- auto zLength = __length(zShapeInfo);
+ auto zLength = length(zShapeInfo);

auto tid = threadIdx.x + blockIdx.x * blockDim.x;

@ -164,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = __getIndexOffset(e, xShapeInfo);
+ auto xOffset = getIndexOffset(e, xShapeInfo);
- auto yOffset = __getIndexOffset(e, yShapeInfo);
+ auto yOffset = getIndexOffset(e, yShapeInfo);
- auto zOffset = __getIndexOffset(e, zShapeInfo);
+ auto zOffset = getIndexOffset(e, zShapeInfo);

z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
}

@ -188,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);

- auto zLength = __length(zShapeInfo);
+ auto zLength = length(zShapeInfo);

auto tid = threadIdx.x + blockIdx.x * blockDim.x;

@ -197,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = __getIndexOffset(e, xShapeInfo);
+ auto xOffset = getIndexOffset(e, xShapeInfo);
- auto yOffset = __getIndexOffset(e, yShapeInfo);
+ auto yOffset = getIndexOffset(e, yShapeInfo);
- auto zOffset = __getIndexOffset(e, zShapeInfo);
+ auto zOffset = getIndexOffset(e, zShapeInfo);

z[zOffset] = lambda(x[xOffset], y[yOffset]);
}

@ -224,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);

- auto zLength = __length(zShapeInfo);
+ auto zLength = length(zShapeInfo);

auto tid = threadIdx.x + blockIdx.x * blockDim.x;

@ -233,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto wOffset = __getIndexOffset(e, wShapeInfo);
+ auto wOffset = getIndexOffset(e, wShapeInfo);
- auto xOffset = __getIndexOffset(e, xShapeInfo);
+ auto xOffset = getIndexOffset(e, xShapeInfo);
- auto yOffset = __getIndexOffset(e, yShapeInfo);
+ auto yOffset = getIndexOffset(e, yShapeInfo);
- auto zOffset = __getIndexOffset(e, zShapeInfo);
+ auto zOffset = getIndexOffset(e, zShapeInfo);

z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
}
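Two things happen in this file: the includes move under system/, and the device helpers lose their double-underscore prefix (identifiers beginning with __ are reserved for the implementation in C and C++, which is presumably why __getIndexOffset and __length were renamed). All of the kernels above share the same grid-stride loop; a stripped-down sketch of that pattern in isolation (illustrative only, assuming the lambda is passed as a functor template parameter as in the kernels above):

// Each thread starts at its global id and strides by the grid's total
// thread count, so any zLength is covered by any launch configuration.
template <typename T, typename Lambda>
static _CUDA_G void gridStrideSketch(const T *x, T *z, Nd4jLong zLength, Lambda lambda) {
    auto tid = threadIdx.x + blockIdx.x * blockDim.x;
    for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x)
        z[e] = lambda(x[e]);
}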
@ -26,25 +26,25 @@
#include <string>
#include <atomic>
#include <unordered_map>
- #include <NDArray.h>
+ #include <array/NDArray.h>
#include <memory/Workspace.h>
- #include <dll.h>
+ #include <system/dll.h>

- namespace nd4j {
+ namespace sd {
class ND4J_EXPORT NDArrayList {
private:
// workspace where chunks belong to
- //nd4j::memory::Workspace* _workspace = nullptr;
+ //sd::memory::Workspace* _workspace = nullptr;
- nd4j::LaunchContext * _context = nd4j::LaunchContext ::defaultContext();
+ sd::LaunchContext * _context = sd::LaunchContext ::defaultContext();

// numeric and symbolic ids of this list
std::pair<int, int> _id;
std::string _name;

- nd4j::DataType _dtype;
+ sd::DataType _dtype;

// stored chunks
- std::map<int, nd4j::NDArray*> _chunks;
+ MAP_IMPL<int, sd::NDArray*> _chunks;

// just a counter, for stored elements
std::atomic<int> _elements;

@ -65,7 +65,7 @@ namespace nd4j {
NDArrayList(int height, bool expandable = false);
~NDArrayList();

- nd4j::DataType dataType();
+ sd::DataType dataType();

NDArray* read(int idx);
NDArray* readRaw(int idx);

@ -82,8 +82,8 @@ namespace nd4j {

std::pair<int,int>& id();
std::string& name();
- //nd4j::memory::Workspace* workspace();
+ //sd::memory::Workspace* workspace();
- nd4j::LaunchContext * context();
+ sd::LaunchContext * context();
NDArrayList* clone();

bool equals(NDArrayList& other);
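A short hypothetical sketch of the NDArrayList API declared above (assumes the list has already been populated through the write path, which is not shown in this hunk):

// height of 10 slots, not expandable beyond that
sd::NDArrayList list(10, false);

// dataType() reports the common dtype of the stored chunks;
// read(idx) returns the array stored at that slot
sd::DataType dt = list.dataType();
sd::NDArray *first = list.read(0);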
@ -19,7 +19,7 @@
//
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
//
- // Created by raver119 on 07.09.17.
+ // @author raver119@gmail.com
//

#ifndef LIBND4J_RESULTSET_H

@ -27,22 +27,27 @@

#include <vector>
#include <graph/generated/result_generated.h>
- #include <pointercast.h>
+ #include <system/pointercast.h>
- #include <dll.h>
+ #include <system/dll.h>

- namespace nd4j {
+ namespace sd {

class NDArray; // forward declaration of template class NDArray

class ND4J_EXPORT ResultSet {
private:
- std::vector<nd4j::NDArray *> _content;
+ std::vector<sd::NDArray *> _content;
Nd4jStatus _status = ND4J_STATUS_OK;
bool _removable = true;

+ void delContent();
+
public:
- // default constructor
- ResultSet(const nd4j::graph::FlatResult* result = nullptr);
+ explicit ResultSet();
+
+ #ifndef __JAVACPP_HACK__
+ ResultSet(const sd::graph::FlatResult* result);
+ #endif

ResultSet(const ResultSet& other) noexcept;

@ -57,9 +62,9 @@ namespace nd4j {
~ResultSet();

int size();
- nd4j::NDArray* at(const unsigned long idx) const;
+ sd::NDArray* at(const unsigned long idx) const;
- nd4j::NDArray* operator[](const unsigned long idx) const;
+ sd::NDArray* operator[](const unsigned long idx) const;
- void push_back(nd4j::NDArray* array);
+ void push_back(sd::NDArray* array);

Nd4jStatus status();
void setStatus(Nd4jStatus status);
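A hedged sketch of consuming the ResultSet API shown above, using only size() and at() from this header (the header's own note says the set deletes its stored NDArrays in the destructor, so the pointers are non-owning from the caller's side):

void consume(sd::ResultSet &results) {
    // at(idx) hands back sd::NDArray*; do not delete these yourself
    for (int e = 0; e < results.size(); e++) {
        sd::NDArray *arr = results.at(e);
        // ... use arr here ...
    }
}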
@ -23,12 +23,12 @@

#include <unordered_map>
#include <vector>
- #include <dll.h>
+ #include <system/dll.h>
- #include <pointercast.h>
+ #include <system/pointercast.h>
- #include <DataType.h>
+ #include <array/DataType.h>
#include <initializer_list>

- namespace nd4j {
+ namespace sd {

class ND4J_EXPORT ShapeDescriptor {

@ -44,7 +44,7 @@ class ND4J_EXPORT ShapeDescriptor {
public:
ShapeDescriptor(const ShapeDescriptor &other);
ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype = true);
- explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const nd4j::DataType dtypeOverride);
+ explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const sd::DataType dtypeOverride);
explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride);
explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, const Nd4jLong *orderOverride);
explicit ShapeDescriptor(const DataType type, const Nd4jLong length);

@ -85,9 +85,19 @@ class ND4J_EXPORT ShapeDescriptor {
static ShapeDescriptor scalarDescriptor(const DataType type);
static ShapeDescriptor vectorDescriptor(const Nd4jLong length, const DataType type);
};


}

+ #ifndef __JAVACPP_HACK__
+
+ namespace std {
+ template<>
+ class ND4J_EXPORT hash<sd::ShapeDescriptor> {
+ public:
+ size_t operator()(const sd::ShapeDescriptor &k) const;
+ };
+ }
+
+ #endif

#endif //DEV_TESTS_SHAPEDESCRIPTOR_H
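The new std::hash<sd::ShapeDescriptor> specialization (guarded from JavaCPP) is what lets a descriptor serve as a key in unordered containers; a sketch of the intended use, assuming ShapeDescriptor also provides the equality operator such containers require:

#include <unordered_map>

// e.g. a cache from shape descriptor to an interned shapeInfo buffer,
// now possible because std::hash is specialized for the descriptor
std::unordered_map<sd::ShapeDescriptor, Nd4jLong*> shapeCache;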
@ -22,10 +22,10 @@
#define LIBND4J_SHAPELIST_H

#include <vector>
- #include <shape.h>
+ #include <helpers/shape.h>
- #include <dll.h>
+ #include <system/dll.h>

- namespace nd4j {
+ namespace sd {
class ND4J_EXPORT ShapeList {
protected:
std::vector<Nd4jLong*> _shapes;
@ -21,7 +21,7 @@
#ifndef ND4J_SPACE_TYPE_H
#define ND4J_SPACE_TYPE_H

- namespace nd4j {
+ namespace sd {
enum SpaceType {
CONTINUOUS = 1,
COMPLEX = 2,
@ -21,7 +21,7 @@
#ifndef LIBND4J_SPARSETYPE_H
#define LIBND4J_SPARSETYPE_H

- namespace nd4j {
+ namespace sd {
enum SparseType {
CSR = 1,
CSC = 2,
@ -22,9 +22,9 @@
#define DEV_TESTS_TADDESCRIPTOR_H

#include "ShapeDescriptor.h"
- #include <dll.h>
+ #include <system/dll.h>

- namespace nd4j {
+ namespace sd {
class ND4J_EXPORT TadDescriptor {
private:
ShapeDescriptor _originalShape;

@ -53,9 +53,22 @@ namespace nd4j {

std::vector<int>& axis();
ShapeDescriptor& originalShape();
+ ShapeDescriptor const& originalShapeConst() const;
bool areUnitiesinShape() const;
};
}

+ #ifndef __JAVACPP_HACK__
+
+ namespace std {
+ template<>
+ class ND4J_EXPORT hash<sd::TadDescriptor> {
+ public:
+ size_t operator()(const sd::TadDescriptor &k) const;
+ };
+ }
+
+ #endif

#endif //DEV_TESTS_TADDESCRIPTOR_H
@ -23,7 +23,7 @@

#include "ConstantDataBuffer.h"

- namespace nd4j {
+ namespace sd {
class ND4J_EXPORT TadPack {
private:
ConstantDataBuffer _tadShape;
@ -19,10 +19,10 @@
// @author Yurii Shyrma (iuriish@yahoo.com)
//

- #include "../DataBuffer.h"
+ #include <array/DataBuffer.h>
- #include <DataTypeUtils.h>
+ #include <array/DataTypeUtils.h>

- namespace nd4j {
+ namespace sd {
void DataBuffer::expand(const uint64_t size) {
if (size > _lenInBytes) {
// allocate new buffer
@ -17,15 +17,15 @@
#ifndef NDARRAY_CPP
#define NDARRAY_CPP

- #include "../NDArray.h"
+ #include <array/NDArray.h>
- #include "../NDArrayFactory.h"
+ #include <array/NDArrayFactory.h>
- #include "NativeOpExecutioner.h"
+ #include <legacy/NativeOpExecutioner.h>
- #include <BroadcastPairwiseConverter.h>
+ #include <loops/BroadcastPairwiseConverter.h>
#include <memory/Workspace.h>
#include <memory/MemoryRegistrator.h>
- #include <ops.h>
+ #include <ops/ops.h>
#include <ops/gemm.h>
- #include <pointercast.h>
+ #include <system/pointercast.h>
#include <stdexcept>
#include <memory>
#include <helpers/logger.h>

@ -38,16 +38,16 @@
#include <helpers/ShapeUtils.h>
#include <sstream>
#include <helpers/ArrayUtils.h>
- #include <MmulHelper.h>
+ #include <helpers/MmulHelper.h>
#include <helpers/threshold.h>
#include <exceptions/datatype_exception.h>
#include <exceptions/allocation_exception.h>
#include <helpers/ConstantTadHelper.h>

- #include <NDArray.hpp>
+ #include <array/NDArray.hXX>


- namespace nd4j {
+ namespace sd {

////////////////////////////////////////////////////////////////////////
@ -95,22 +95,29 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t

     const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo());

     auto func = PRAGMA_THREADS_FOR {
-        Nd4jLong coords[MAX_RANK];
+        int coords[MAX_RANK], temp;

         for (auto i = start; i < stop; i++) {
-            shape::index2coords(i, target.getShapeInfo(), coords);
+            shape::index2coordsCPU(start, i, target.getShapeInfo(), coords);
             const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);

             // if( (row + upper < col) || (row + lower > col) )
             if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
                 z[zOffset] = value;
             else if (this != &target) { // when this and target are different arrays
-                if (xRank != zRank)
+                if (xRank != zRank) {
+                    temp = coords[0];
                     coords[0] = coords[1];
+                }

                 const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords);
                 z[zOffset] = x[xOffset];
+
+                if (xRank != zRank) // restore first coordinate
+                    coords[0] = temp;
             }
         }
     };
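Two related optimizations appear in this hunk and recur below: the per-thread coordinate scratch shrinks from Nd4jLong to int, and the full index2coords recomputation is replaced by index2coordsCPU, which takes the thread's starting index as an extra argument. The apparent intent is that within a thread's contiguous [start, stop) range, consecutive indices differ by one, so coordinates can be advanced incrementally instead of being rederived with a div/mod per dimension. A minimal sketch of that incremental idea under those assumptions (index2coords_full/index2coords_step are hypothetical stand-ins for the libnd4j helpers):

#include <cstdio>

// Full recomputation: div/mod per dimension, C ("row-major") ordering.
static void index2coords_full(long i, const int* shape, int rank, int* coords) {
    for (int j = rank - 1; j >= 0; --j) { coords[j] = i % shape[j]; i /= shape[j]; }
}

// Incremental variant: only the first index of a thread's range pays the full
// cost; every following index is a ripple-carry increment of the coordinates.
static void index2coords_step(long start, long i, const int* shape, int rank, int* coords) {
    if (i == start) { index2coords_full(i, shape, rank, coords); return; }
    for (int j = rank - 1; j >= 0; --j) {
        if (++coords[j] < shape[j]) return;   // cheap in the common case
        coords[j] = 0;                        // carry into the next dimension
    }
}

int main() {
    const int shape[2] = {3, 4};
    int coords[2];
    for (long i = 0; i < 12; ++i) {
        index2coords_step(0, i, shape, 2, coords);
        std::printf("%ld -> [%d,%d]\n", i, coords[0], coords[1]);
    }
    return 0;
}

The temp save/restore added around coords[0] serves the same frugality: since the coordinate array now persists across iterations, any in-place mutation must be undone before the next index is derived from it.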
@ -308,7 +315,7 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {

     // fill newBuff, loop through all elements of newBuff
     // looping through _buffer goes automatically by means of getSubArrayIndex applying
     const int ews = target.ews();
-    const int targetLen = target.lengthOf();
+    const auto targetLen = target.lengthOf();
     if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here
 //#pragma omp parallel for simd if(targetLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)
         for(Nd4jLong i=0; i<targetLen; ++i) {
@ -372,16 +379,20 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int

     const int rank = input.rankOf();     // xRank = zRank
     const int zLen = output.lengthOf();  // xLen <= zLen
-    const int repSize = repeats.size();
+    const uint repSize = repeats.size();

     // loop through input array
     auto func = PRAGMA_THREADS_FOR {
-        Nd4jLong coords[MAX_RANK];
-        for (auto i = start; i < stop; i++) {
-            shape::index2coords(i, output.getShapeInfo(), coords);
+        int coords[MAX_RANK], temp;
+
+        for (auto i = start; i < stop; i++) {
+            shape::index2coordsCPU(start, i, output.getShapeInfo(), coords);
             const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);

+            temp = coords[axis];
+
             if (repSize > 1) {
                 for (uint j = 0; j < repSize; ++j) {
                     coords[axis] -= repeats[j];

@ -394,6 +405,8 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int

                 coords[axis] /= repeats[0];

             z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords)];
+
+            coords[axis] = temp;
         }
     };
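For context on what repeat_ is undoing here: each output coordinate along `axis` must be mapped back to the input coordinate it was repeated from. The visible part of the loop subtracts successive repeat counts; presumably (the middle of the loop is elided from this diff) the first bucket that drives the coordinate negative identifies the matching input slot, while a uniform repeat reduces to integer division. A hedged standalone sketch of that mapping, with hypothetical names:

#include <cassert>
#include <vector>

// Hypothetical output->input coordinate mapping along the repeated axis;
// the real code mutates coords[axis] in place inside the parallel loop.
static int outToInCoord(int outCoord, const std::vector<int>& repeats) {
    if (repeats.size() == 1)
        return outCoord / repeats[0];            // uniform repeat: plain division
    for (size_t j = 0; j < repeats.size(); ++j) {
        outCoord -= repeats[j];                  // walk the repeat buckets
        if (outCoord < 0)
            return static_cast<int>(j);          // landed in input slot j
    }
    assert(false && "outCoord out of range");
    return -1;
}

int main() {
    // input length 3 along the axis, repeats {1,2,3} -> output length 6:
    // output coords 0..5 map to input coords 0,1,1,2,2,2
    std::vector<int> reps{1, 2, 3};
    assert(outToInCoord(0, reps) == 0);
    assert(outToInCoord(2, reps) == 1);
    assert(outToInCoord(5, reps) == 2);
    return 0;
}

The added temp save/restore of coords[axis] is again required because the coordinate array now persists across loop iterations for the incremental index2coordsCPU.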
@ -0,0 +1,148 @@
+################################################################################
+# Copyright (c) 2015-2018 Skymind, Inc.
+#
+# This program and the accompanying materials are made available under the
+# terms of the Apache License, Version 2.0 which is available at
+# https://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+################################################################################
+
+#ifndef NDARRAY_MACRO
+#define NDARRAY_MACRO
+
+#include <op_boilerplate.h>
+
+//NDArray<T> *other, T *extraParams
+BUILD_CALL_1(template void NDArray<float>::template applyPairwiseTransform, float, (NDArray<float>* other, float* extraParams), PAIRWISE_TRANSFORM_OPS)
+BUILD_CALL_1(template void NDArray<float16>::applyPairwiseTransform, float16, (NDArray<float16>* other, float16* extraParams), PAIRWISE_TRANSFORM_OPS)
+BUILD_CALL_1(template void NDArray<double>::applyPairwiseTransform, double, (NDArray<double>* other, double* extraParams), PAIRWISE_TRANSFORM_OPS)
+
+// NDArray<T> *other, NDArray<T> *target, T *extraParams
+BUILD_CALL_1(template void sd::NDArray<float>::applyPairwiseTransform, float, (NDArray<float>* other, NDArray<float>* target, float* extraParams), PAIRWISE_TRANSFORM_OPS)
+BUILD_CALL_1(template void sd::NDArray<float16>::applyPairwiseTransform, float16, (NDArray<float16>* other, NDArray<float16>* target, float16* extraParams), PAIRWISE_TRANSFORM_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyPairwiseTransform, double, (NDArray<double>* other, NDArray<double>* target, double* extraParams), PAIRWISE_TRANSFORM_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float16>::applyScalar, float16, (NDArray<float16>& scalar, NDArray<float16>* target, float16 *extraParams) const, SCALAR_OPS)
+BUILD_CALL_1(template void sd::NDArray<float16>::applyScalar, float16, (float16 scalar, NDArray<float16>* target, float16 *extraParams) const, SCALAR_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float>::applyScalar, float, (NDArray<float>& scalar, NDArray<float>* target, float *extraParams) const, SCALAR_OPS)
+BUILD_CALL_1(template void sd::NDArray<float>::applyScalar, float, (float scalar, NDArray<float>* target, float *extraParams) const, SCALAR_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<double>::applyScalar, double, (NDArray<double>& scalar, NDArray<double>* target, double *extraParams) const, SCALAR_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyScalar, double, (double scalar, NDArray<double>* target, double *extraParams) const, SCALAR_OPS)
+
+
+BUILD_CALL_1(template float16 sd::NDArray<float16>::reduceNumber, float16, (float16 *extraParams) const, REDUCE_OPS)
+BUILD_CALL_1(template float sd::NDArray<float>::reduceNumber, float, (float *extraParams) const, REDUCE_OPS)
+BUILD_CALL_1(template double sd::NDArray<double>::reduceNumber, double, (double *extraParams) const, REDUCE_OPS)
+
+BUILD_CALL_1(template Nd4jLong sd::NDArray<float16>::indexReduceNumber, float16, (float16 *extraParams), INDEX_REDUCE_OPS)
+BUILD_CALL_1(template Nd4jLong sd::NDArray<float>::indexReduceNumber, float, (float *extraParams), INDEX_REDUCE_OPS)
+BUILD_CALL_1(template Nd4jLong sd::NDArray<double>::indexReduceNumber, double, (double *extraParams), INDEX_REDUCE_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float16>::applyBroadcast, float16, (std::initializer_list<int> list, const sd::NDArray<float16>* a, sd::NDArray<float16>* b, float16* c), BROADCAST_OPS)
+BUILD_CALL_1(template void sd::NDArray<float>::applyBroadcast, float, (std::initializer_list<int> list, const sd::NDArray<float>* a, sd::NDArray<float>* b, float* c), BROADCAST_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyBroadcast, double, (std::initializer_list<int> list, const sd::NDArray<double>* a, sd::NDArray<double>* b, double* c), BROADCAST_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float16>::applyTrueBroadcast, float16,(const sd::NDArray<float16>* a, sd::NDArray<float16>* target, const bool checkTargetShape, float16* c) const, BROADCAST_OPS)
+BUILD_CALL_1(template void sd::NDArray<float>::applyTrueBroadcast, float, (const sd::NDArray<float>* a, sd::NDArray<float>* target, const bool checkTargetShape, float* c) const, BROADCAST_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyTrueBroadcast, double, (const sd::NDArray<double>* a, sd::NDArray<double>* target, const bool checkTargetShape, double* c) const, BROADCAST_OPS)
+
+BUILD_CALL_1(template sd::NDArray<float16>* sd::NDArray<float16>::applyTrueBroadcast, float16, (const sd::NDArray<float16>* a, float16* c) const, BROADCAST_OPS)
+BUILD_CALL_1(template sd::NDArray<float>* sd::NDArray<float>::applyTrueBroadcast, float, (const sd::NDArray<float>* a, float* c) const, BROADCAST_OPS)
+BUILD_CALL_1(template sd::NDArray<double>* sd::NDArray<double>::applyTrueBroadcast, double, (const sd::NDArray<double>* a, double* c) const, BROADCAST_OPS)
+
+BUILD_CALL_1(template sd::NDArray<float16> sd::NDArray<float16>::applyTrueBroadcast, float16, (const sd::NDArray<float16>& a, float16* c) const, BROADCAST_OPS)
+BUILD_CALL_1(template sd::NDArray<float> sd::NDArray<float>::applyTrueBroadcast, float, (const sd::NDArray<float>& a, float* c) const, BROADCAST_OPS)
+BUILD_CALL_1(template sd::NDArray<double> sd::NDArray<double>::applyTrueBroadcast, double, (const sd::NDArray<double>& a, double* c) const, BROADCAST_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float16>::applyTransform, float16, (NDArray<float16>* target, float16* extraParams), TRANSFORM_OPS)
+BUILD_CALL_1(template void sd::NDArray<float>::applyTransform, float, (NDArray<float>* target, float* extraParams), TRANSFORM_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyTransform, double, (NDArray<double>* target, double* extraParams), TRANSFORM_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float16>::applyTransform, float16, (float16* extraParams), TRANSFORM_OPS)
+BUILD_CALL_1(template void sd::NDArray<float>::applyTransform, float, (float* extraParams), TRANSFORM_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyTransform, double, (double* extraParams), TRANSFORM_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float16>::applyRandom, float16, (sd::random::RandomBuffer *buffer, NDArray<float16>* y, NDArray<float16>* z, float16* extraParams), RANDOM_OPS)
+BUILD_CALL_1(template void sd::NDArray<float>::applyRandom, float, (sd::random::RandomBuffer *buffer, NDArray<float>* y, NDArray<float>* z, float* extraParams), RANDOM_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::applyRandom, double, (sd::random::RandomBuffer *buffer, NDArray<double>* y, NDArray<double>* z, double* extraParams), RANDOM_OPS)
+
+BUILD_CALL_1(template NDArray<float16> sd::NDArray<float16>::transform, float16, (float16* extraParams) const, TRANSFORM_OPS)
+BUILD_CALL_1(template NDArray<float> sd::NDArray<float>::transform, float, (float* extraParams) const, TRANSFORM_OPS)
+BUILD_CALL_1(template NDArray<double> sd::NDArray<double>::transform, double, (double* extraParams) const, TRANSFORM_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template reduceAlongDimension, float, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template reduceAlongDimension, float16, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template reduceAlongDimension, double, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+
+BUILD_CALL_1(template NDArray<float> sd::NDArray<float>::template reduceAlongDims, float, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+BUILD_CALL_1(template NDArray<float16> sd::NDArray<float16>::template reduceAlongDims, float16, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+BUILD_CALL_1(template NDArray<double> sd::NDArray<double>::template reduceAlongDims, double, (const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template reduceAlongDimension, float, (const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template reduceAlongDimension, float16, (const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template reduceAlongDimension, double, (const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const, REDUCE_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float>::template reduceAlongDimension, float, (NDArray<float>* target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, float * extras) const, REDUCE_OPS)
+BUILD_CALL_1(template void sd::NDArray<float16>::template reduceAlongDimension, float16, (NDArray<float16>* target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, float16 * extras) const, REDUCE_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::template reduceAlongDimension, double, (NDArray<double>* target, const std::vector<int>& dimension, const bool keepDims, const bool supportOldShapes, double * extras) const, REDUCE_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template varianceAlongDimension, float, (const bool biasCorrected, const std::initializer_list<int>& dimensions) const, SUMMARY_STATS_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template varianceAlongDimension, float16, (const bool biasCorrected, const std::initializer_list<int>& dimensions) const, SUMMARY_STATS_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template varianceAlongDimension, double, (const bool biasCorrected, const std::initializer_list<int>& dimensions) const, SUMMARY_STATS_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float>::template varianceAlongDimension, float, (const NDArray<float> *target, const bool biasCorrected, const std::initializer_list<int>& dimensions), SUMMARY_STATS_OPS)
+BUILD_CALL_1(template void sd::NDArray<float16>::template varianceAlongDimension, float16, (const NDArray<float16> *target,const bool biasCorrected, const std::initializer_list<int>& dimensions), SUMMARY_STATS_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::template varianceAlongDimension, double, (const NDArray<double> *target, const bool biasCorrected, const std::initializer_list<int>& dimensions), SUMMARY_STATS_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float>::template varianceAlongDimension, float, (const NDArray<float> *target, const bool biasCorrected, const std::vector<int>& dimensions), SUMMARY_STATS_OPS)
+BUILD_CALL_1(template void sd::NDArray<float16>::template varianceAlongDimension, float16, (const NDArray<float16> *target,const bool biasCorrected, const std::vector<int>& dimensions), SUMMARY_STATS_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::template varianceAlongDimension, double, (const NDArray<double> *target, const bool biasCorrected, const std::vector<int>& dimensions), SUMMARY_STATS_OPS)
+
+BUILD_CALL_1(template float sd::NDArray<float>::template varianceNumber, float, (bool biasCorrected), SUMMARY_STATS_OPS)
+BUILD_CALL_1(template float16 sd::NDArray<float16>::template varianceNumber, float16, (bool biasCorrected), SUMMARY_STATS_OPS)
+BUILD_CALL_1(template double sd::NDArray<double>::template varianceNumber, double, (bool biasCorrected), SUMMARY_STATS_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template applyReduce3, float, (const NDArray<float>* other, const float* extraParams) const, REDUCE3_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template applyReduce3, float16, (const NDArray<float16>* other, const float16* extraParams) const, REDUCE3_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template applyReduce3, double, (const NDArray<double>* other, const double* extraParams) const, REDUCE3_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template applyReduce3, float, (const NDArray<float>* other, const std::vector<int> &dims, const float* extraParams) const, REDUCE3_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template applyReduce3, float16, (const NDArray<float16>* other, const std::vector<int> &dims, const float16* extraParams) const, REDUCE3_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template applyReduce3, double, (const NDArray<double>* other, const std::vector<int> &dims, const double* extraParams) const, REDUCE3_OPS)
+
+BUILD_CALL_1(template void sd::NDArray<float>::template applyIndexReduce, float, (const NDArray<float>* target, const std::vector<int> & alpha, const float* beta) const, INDEX_REDUCE_OPS)
+BUILD_CALL_1(template void sd::NDArray<float16>::template applyIndexReduce, float16, (const NDArray<float16>* target, const std::vector<int> & alpha, const float16* beta) const, INDEX_REDUCE_OPS)
+BUILD_CALL_1(template void sd::NDArray<double>::template applyIndexReduce, double, (const NDArray<double>* target, const std::vector<int> & alpha, const double* beta) const, INDEX_REDUCE_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template applyIndexReduce, float, (const std::vector<int> & alpha, const float* beta) const, INDEX_REDUCE_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template applyIndexReduce, float16, (const std::vector<int> & alpha, const float16* beta) const, INDEX_REDUCE_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template applyIndexReduce, double, (const std::vector<int> & alpha, const double* beta) const, INDEX_REDUCE_OPS)
+
+BUILD_CALL_1(template NDArray<float> *sd::NDArray<float>::template applyAllReduce3, float, (const sd::NDArray<float>* alpha, const std::vector<int> & beta, float const* gamma) const, REDUCE3_OPS)
+BUILD_CALL_1(template NDArray<float16> *sd::NDArray<float16>::template applyAllReduce3, float16, (const sd::NDArray<float16>* alpha, const std::vector<int> & beta, float16 const* gamma) const, REDUCE3_OPS)
+BUILD_CALL_1(template NDArray<double> *sd::NDArray<double>::template applyAllReduce3, double, (const sd::NDArray<double>* alpha, const std::vector<int> & beta, double const* gamma) const, REDUCE3_OPS)
+
+template NDArray<float> mmul(const NDArray<float>& left, const NDArray<float>& right);
+template NDArray<float16> mmul(const NDArray<float16>& left, const NDArray<float16>& right);
+template NDArray<double> mmul(const NDArray<double>& left, const NDArray<double>& right);
+
+// template NDArray<float> operator-(const float, const NDArray<float>&);
+// template NDArray<float16> operator-(const float16, const NDArray<float16>&);
+// template NDArray<double> operator-(const double, const NDArray<double>&);
+
+// template NDArray<float> operator+(const float, const NDArray<float>&);
+// template NDArray<float16> operator+(const float16, const NDArray<float16>&);
+// template NDArray<double> operator+(const double, const NDArray<double>&);
+
+#endif
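The all-new file above is pure instantiation boilerplate: each BUILD_CALL_1 line asks an X-macro from op_boilerplate.h to stamp out one explicit template instantiation per operation in the trailing ops list (PAIRWISE_TRANSFORM_OPS, REDUCE_OPS, and so on), for each of float, float16, and double. A minimal sketch of that X-macro pattern, with hypothetical names (MY_OPS, INSTANTIATE, invoke); libnd4j's real machinery is considerably more involved:

#include <cstdio>

// An ops list macro invokes a callback once per (number, op) pair.
#define MY_OPS(CB)  CB(0, Add)  CB(1, Mul)

template<typename T> struct Add { static T apply(T a, T b) { return a + b; } };
template<typename T> struct Mul { static T apply(T a, T b) { return a * b; } };

template<typename T, typename Op>
T invoke(T a, T b) { return Op::apply(a, b); }

// One explicit instantiation of invoke<float, Op> per entry in MY_OPS.
#define INSTANTIATE(NUM, OP) template float invoke<float, OP<float>>(float, float);
MY_OPS(INSTANTIATE)
#undef INSTANTIATE

int main() {
    std::printf("%f %f\n",
                invoke<float, Add<float>>(2, 3),   // 5
                invoke<float, Mul<float>>(2, 3));  // 6
    return 0;
}

The payoff of the pattern is that adding an op to the central list automatically produces every required instantiation, keeping the template definitions out of headers and compile times bounded.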
@ -20,14 +20,14 @@
 //

 #include "../DataBuffer.h"
-#include <DataTypeUtils.h>
-#include <op_boilerplate.h>
+#include <array/DataTypeUtils.h>
+#include <system/op_boilerplate.h>
 #include <exceptions/cuda_exception.h>
 #include <execution/AffinityManager.h>
 #include <memory/MemoryCounter.h>
 #include <exceptions/allocation_exception.h>

-namespace nd4j {
+namespace sd {
 void DataBuffer::expand(const uint64_t size) {
     if (size > _lenInBytes) {
         // allocate new buffer
@ -67,19 +67,19 @@ namespace nd4j {
 void DataBuffer::allocateSpecial() {

     if (_specialBuffer == nullptr && getLenInBytes() > 0) {
-        auto deviceId = nd4j::AffinityManager::currentDeviceId();
+        auto deviceId = sd::AffinityManager::currentDeviceId();

         if (_workspace == nullptr)
-            if (!nd4j::memory::MemoryCounter::getInstance()->validate(getLenInBytes()))
-                throw nd4j::allocation_exception::build("Requested amount exceeds device limits", nd4j::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes());
+            if (!sd::memory::MemoryCounter::getInstance()->validate(getLenInBytes()))
+                throw sd::allocation_exception::build("Requested amount exceeds device limits", sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes());

         ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t);
         _isOwnerSpecial = true;

         if (_workspace == nullptr) {
-            nd4j::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes());
-            nd4j::memory::MemoryCounter::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, getLenInBytes());
+            sd::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes());
+            sd::memory::MemoryCounter::getInstance()->countIn(sd::memory::MemoryType::DEVICE, getLenInBytes());
         }
     }
 }

@ -135,8 +135,8 @@ void DataBuffer::deleteSpecial() {

     // count out towards DataBuffer device, only if we're not in workspace
     if (_workspace == nullptr) {
-        nd4j::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes());
-        nd4j::memory::MemoryCounter::getInstance()->countOut(nd4j::memory::MemoryType::DEVICE, getLenInBytes());
+        sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes());
+        sd::memory::MemoryCounter::getInstance()->countOut(sd::memory::MemoryType::DEVICE, getLenInBytes());
     }
 }
 }
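Beyond the mechanical namespace rename, the allocateSpecial/deleteSpecial pair shows the accounting discipline around device allocations: validate() before allocating, countIn() after success, and a matching countOut() on free, all skipped when a workspace owns the memory. A hedged, heavily simplified sketch of that counter pattern (ByteCounter is hypothetical; the real MemoryCounter is a singleton keyed by device id and memory type):

#include <atomic>
#include <cstdint>
#include <stdexcept>

// Per-device byte accounting: validate() checks a request against the limit,
// countIn()/countOut() track live bytes across allocate/free pairs.
class ByteCounter {
    std::atomic<int64_t> used_{0};
    int64_t limit_;
public:
    explicit ByteCounter(int64_t limit) : limit_(limit) {}
    bool validate(int64_t bytes) const { return used_.load() + bytes <= limit_; }
    void countIn(int64_t bytes)  { used_ += bytes; }
    void countOut(int64_t bytes) { used_ -= bytes; }
};

int main() {
    ByteCounter dev(1 << 20);                  // a 1 MiB "device"
    const int64_t req = 512 * 1024;
    if (!dev.validate(req))
        throw std::runtime_error("Requested amount exceeds device limits");
    dev.countIn(req);                          // after a successful allocation
    dev.countOut(req);                         // paired with the free
    return 0;
}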
@ -17,14 +17,14 @@
 #ifndef NDARRAY_CPP
 #define NDARRAY_CPP

-#include "../NDArray.h"
-#include "../NDArrayFactory.h"
-#include "NativeOpExecutioner.h"
+#include <array/NDArray.h>
+#include <array/NDArrayFactory.h>
+#include <legacy/NativeOpExecutioner.h>
 #include <memory/Workspace.h>
 #include <memory/MemoryRegistrator.h>
-#include <ops.h>
+#include <ops/ops.h>
 #include <ops/gemm.h>
-#include <pointercast.h>
+#include <system/pointercast.h>
 #include <stdexcept>
 #include <memory>
 #include <helpers/logger.h>

@ -37,17 +37,17 @@
 #include <helpers/ShapeUtils.h>
 #include <sstream>
 #include <helpers/ArrayUtils.h>
-#include <MmulHelper.h>
+#include <helpers/MmulHelper.h>
 #include <helpers/threshold.h>
 #include <exceptions/datatype_exception.h>
 #include <exceptions/cuda_exception.h>
-#include <specials_cuda.h>
+#include <ops/specials_cuda.h>
 #include <loops/special_kernels.h>
-#include <PointersManager.h>
-#include "../NDArray.hpp"
-#include <ConstantShapeHelper.h>
+#include <helpers/PointersManager.h>
+#include <array/NDArray.hXX>
+#include <helpers/ConstantShapeHelper.h>

-namespace nd4j {
+namespace sd {

 void* NDArray::platformBuffer() { return specialBuffer(); }
 void* NDArray::getPlatformBuffer() const { return getSpecialBuffer(); }
@ -85,12 +85,12 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha
     const auto x = reinterpret_cast<const T*>(vx);
     auto z = reinterpret_cast<T*>(vz);

-    __shared__ int zRank, xRank, areSameOffsets;             // xRank == zRank always, except when xRank = 1, in this case zRank = 2
-    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;      // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
+    __shared__ int zRank, xRank, areSameOffsets, *sharedMem; // xRank == zRank always, except when xRank = 1, in this case zRank = 2
+    __shared__ Nd4jLong zLen, totalThreads;                  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen

     if (threadIdx.x == 0) {
         extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+        sharedMem = reinterpret_cast<int*>(shmem);
         areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
         xRank = shape::rank(xShapeInfo);
         zRank = shape::rank(zShapeInfo);
@ -137,7 +137,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t

     const int threadsPerBlock = MAX_NUM_THREADS / 4;
     const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-    const int sharedMem = threadsPerBlock * sizeof(decltype(*target.getShapeInfo())) * target.rankOf() + 128;
+    const int sharedMem = threadsPerBlock * sizeof(int) * target.rankOf() + 128;

     PointersManager manager(getContext(), "NDArray::fillAsTriangular");
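This CUDA-side change mirrors the CPU one: the per-thread coordinate scratch in dynamic shared memory moves from Nd4jLong (8 bytes) to int (4 bytes), and every launch-site sizing formula is updated to match, halving the shared-memory budget per block. A host-side C++ sketch of the layout and the arithmetic (constants here are illustrative; the real kernels also walk shape info):

#include <cstdio>
#include <vector>

// Each thread owns `rank` coordinate slots, so a block's scratch buffer is
// threadsPerBlock * sizeof(elem) * rank, plus the 128 bytes of slack the PR keeps.
int main() {
    const int threadsPerBlock = 512;   // MAX_NUM_THREADS / 2 in the repeat kernels
    const int rank = 4;

    const size_t oldBytes = threadsPerBlock * sizeof(long long) * rank + 128;  // Nd4jLong
    const size_t newBytes = threadsPerBlock * sizeof(int) * rank + 128;        // int
    std::printf("Nd4jLong scratch: %zu bytes, int scratch: %zu bytes\n", oldBytes, newBytes);

    // Host-side model of the per-thread slicing the kernels do after
    // sharedMem = reinterpret_cast<int*>(shmem):
    std::vector<int> shmem(threadsPerBlock * rank);
    for (int tid = 0; tid < threadsPerBlock; ++tid) {
        int* coords = shmem.data() + tid * rank;   // this thread's slice
        for (int j = 0; j < rank; ++j)
            coords[j] = j;                         // stand-in for index2coords output
    }
    return 0;
}

Shared memory is a scarce per-SM resource, so halving this buffer can raise occupancy; int coordinates are safe as long as no single dimension exceeds INT_MAX, which holds for the shapes these kernels handle.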
@ -155,12 +155,12 @@ __global__ static void identityMatrixCuda(void* vx, const Nd4jLong* xShapeInfo,

     auto x = reinterpret_cast<T*>(vx);

-    __shared__ int rank;
-    __shared__ Nd4jLong len, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
+    __shared__ int rank, *sharedMem;
+    __shared__ Nd4jLong len, totalThreads;             // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen

     if (threadIdx.x == 0) {
         extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+        sharedMem = reinterpret_cast<int*>(shmem);
         rank = shape::rank(xShapeInfo);
         len = shape::length(xShapeInfo);
         totalThreads = gridDim.x * blockDim.x;
@ -201,7 +201,7 @@ void NDArray::setIdentity() {

     const int threadsPerBlock = MAX_NUM_THREADS / 4;
     const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-    const int sharedMem = threadsPerBlock * sizeof(decltype(getShapeInfo())) * rankOf() + 128;
+    const int sharedMem = threadsPerBlock * sizeof(int) * rankOf() + 128;

     PointersManager manager(getContext(), "NDArray::setIdentity");
@ -398,13 +398,13 @@ __global__ static void repeatCuda(const void* vx, const Nd4jLong* xShapeInfo,
     const X* x = reinterpret_cast<const X*>(vx);
     Z* z = reinterpret_cast<Z*>(vz);

-    __shared__ int rank;
-    __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen = zLen
+    __shared__ int rank, *sharedMem;
+    __shared__ Nd4jLong zLen, totalThreads;             // xLen = zLen

     if (threadIdx.x == 0) {

         extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+        sharedMem = reinterpret_cast<int*>(shmem);

         rank = shape::rank(zShapeInfo);   // xRank = zRank
         zLen = shape::length(zShapeInfo); // xLen <= zLen
@ -460,7 +460,7 @@ NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {

     const int threadsPerBlock = MAX_NUM_THREADS / 2;
     const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-    const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+    const int sharedMem = output.rankOf() * sizeof(int) * threadsPerBlock + 128;

     PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector<int>& repeats)");
@ -484,7 +484,7 @@ void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& t

     const int threadsPerBlock = MAX_NUM_THREADS / 2;
     const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-    const int sharedMem = target.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+    const int sharedMem = target.rankOf() * sizeof(int) * threadsPerBlock + 128;

     PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector<int>& repeats)");
@ -569,6 +569,6 @@ template void NDArray::printCurrentBuffer<double>(const bool host, const char* m

 #endif

-} // end namespace nd4j
+} // end namespace sd
 #endif
@ -21,8 +21,8 @@
 #include <array/ByteOrderUtils.h>


-namespace nd4j {
-    ByteOrder ByteOrderUtils::fromFlatByteOrder(nd4j::graph::ByteOrder order) {
+namespace sd {
+    ByteOrder ByteOrderUtils::fromFlatByteOrder(sd::graph::ByteOrder order) {
         return (ByteOrder) order;
     }
 }
@ -20,7 +20,7 @@

 #include "../ConstantDataBuffer.h"

-namespace nd4j {
+namespace sd {
     ConstantDataBuffer::ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, Nd4jLong numEelements, Nd4jLong sizeOf) {
         _primaryBuffer = primary;
         _specialBuffer = special;
@ -19,10 +19,10 @@
 //

 #include <array/ConstantDescriptor.h>
-#include <DataTypeUtils.h>
+#include <array/DataTypeUtils.h>
 #include <stdexcept>

-namespace nd4j {
+namespace sd {
     ConstantDescriptor::ConstantDescriptor(double* values, int length) {
         for (int e = 0; e < length; e++)
             _floatValues.emplace_back(values[e]);
@ -75,3 +75,25 @@ namespace nd4j {
         return isInteger() ? _integerValues.size() : isFloat() ? _floatValues.size() : 0L;
     }
 }
+
+namespace std {
+    size_t hash<sd::ConstantDescriptor>::operator()(const sd::ConstantDescriptor &k) const {
+        using std::hash;
+        // Compute individual hash values for first,
+        // second and third and combine them using XOR
+        // and bit shifting:
+        size_t hashVal = 0;
+        size_t i = 0;
+        if (k.isInteger()) {
+            for (auto v: k.integerValues()) {
+                hashVal ^= std::hash<Nd4jLong>()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2);
+            }
+        }
+        else {
+            for (auto v: k.floatValues()) {
+                hashVal ^= std::hash<double>()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2);
+            }
+        }
+        return hashVal;
+    }
+}
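The combining formula the new hash uses is the well-known boost-style hash_combine recipe: 0x9e3779b9 is the 32-bit golden-ratio constant, and the two shifts spread each element's bits so that, in general, permuted or prefix-sharing sequences hash differently. A standalone sketch of the same recipe (hashCombine is a hypothetical helper name):

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

// Boost-style hash_combine: fold each element into a running seed.
template<typename T>
static void hashCombine(std::size_t& seed, const T& v) {
    seed ^= std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

int main() {
    std::vector<long long> a{1, 2, 3}, b{3, 2, 1};
    std::size_t ha = 0, hb = 0;
    for (auto v : a) hashCombine(ha, v);
    for (auto v : b) hashCombine(hb, v);
    // Order-sensitive: the two sequences hash differently in general.
    std::printf("hash{1,2,3}=%zu hash{3,2,1}=%zu\n", ha, hb);
    return 0;
}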