parent
e51e6ebfd2
commit
e7730eded4
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -17,7 +17,6 @@
|
||||||
package org.nd4j.list;
|
package org.nd4j.list;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -35,35 +34,15 @@ import java.util.*;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
|
@SuppressWarnings("unchecked") //too many of them.
|
||||||
public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X> {
|
public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X> {
|
||||||
protected INDArray container;
|
protected INDArray container;
|
||||||
protected int size;
|
protected int size;
|
||||||
|
|
||||||
|
BaseNDArrayList() {
|
||||||
|
|
||||||
public BaseNDArrayList() {
|
|
||||||
this.container = Nd4j.create(10);
|
this.container = Nd4j.create(10);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Specify the underlying ndarray for this list.
|
|
||||||
* @param container the underlying array.
|
|
||||||
*/
|
|
||||||
public BaseNDArrayList(INDArray container) {
|
|
||||||
this.container = container;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Allocates the container and this list with
|
|
||||||
* the given size
|
|
||||||
* @param size the size to allocate with
|
|
||||||
*/
|
|
||||||
public void allocateWithSize(int size) {
|
|
||||||
container = Nd4j.create(1,size);
|
|
||||||
this.size = size;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get a view of the underlying array
|
* Get a view of the underlying array
|
||||||
* relative to the size of the actual array.
|
* relative to the size of the actual array.
|
||||||
|
@ -321,11 +300,11 @@ public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X
|
||||||
private class NDArrayListIterator implements ListIterator<X> {
|
private class NDArrayListIterator implements ListIterator<X> {
|
||||||
private int curr = 0;
|
private int curr = 0;
|
||||||
|
|
||||||
public NDArrayListIterator(int curr) {
|
NDArrayListIterator(int curr) {
|
||||||
this.curr = curr;
|
this.curr = curr;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NDArrayListIterator() {
|
NDArrayListIterator() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -335,9 +314,9 @@ public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public X next() {
|
public X next() {
|
||||||
Number ret = get(curr);
|
X ret = get(curr);
|
||||||
curr++;
|
curr++;
|
||||||
return (X) ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -347,9 +326,9 @@ public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public X previous() {
|
public X previous() {
|
||||||
Number ret = get(curr - 1);
|
X ret = get(curr - 1);
|
||||||
curr--;
|
curr--;
|
||||||
return (X) ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.list;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An {@link BaseNDArrayList} for float
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class FloatNDArrayList extends BaseNDArrayList<Float> {
|
|
||||||
public FloatNDArrayList() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public FloatNDArrayList(INDArray container) {
|
|
||||||
super(container);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Float get(int i) {
|
|
||||||
Number ret = container.getDouble(i);
|
|
||||||
return ret.floatValue();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.list;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An {@link BaseNDArrayList} for integers
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class IntNDArrayList extends BaseNDArrayList<Integer> {
|
|
||||||
public IntNDArrayList() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public IntNDArrayList(INDArray container) {
|
|
||||||
super(container);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Integer get(int i) {
|
|
||||||
Number ret = container.getDouble(i);
|
|
||||||
return ret.intValue();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -17,7 +17,6 @@
|
||||||
package org.nd4j.list;
|
package org.nd4j.list;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
@ -272,11 +271,11 @@ public class NDArrayList extends BaseNDArrayList<Double> {
|
||||||
private class NDArrayListIterator implements ListIterator<Double> {
|
private class NDArrayListIterator implements ListIterator<Double> {
|
||||||
private int curr = 0;
|
private int curr = 0;
|
||||||
|
|
||||||
public NDArrayListIterator(int curr) {
|
NDArrayListIterator(int curr) {
|
||||||
this.curr = curr;
|
this.curr = curr;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NDArrayListIterator() {
|
NDArrayListIterator() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,31 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.list.matrix;
|
|
||||||
|
|
||||||
import org.nd4j.list.FloatNDArrayList;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A {@link MatrixBaseNDArrayList}
|
|
||||||
* for float data type
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class FloatMatrixNDArrayList extends MatrixBaseNDArrayList<FloatNDArrayList> {
|
|
||||||
public FloatMatrixNDArrayList() {
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.list.matrix;
|
|
||||||
|
|
||||||
import org.nd4j.list.IntNDArrayList;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A {@link MatrixBaseNDArrayList}
|
|
||||||
* for int data type
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class IntMatrixNDArrayList extends MatrixBaseNDArrayList<IntNDArrayList> {
|
|
||||||
public IntMatrixNDArrayList() {
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,184 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.list.matrix;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.list.BaseNDArrayList;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An {@link ArrayList} like implementation of {@link List}
|
|
||||||
* using {@link INDArray} as the backing data structure.
|
|
||||||
*
|
|
||||||
* Creates an internal container of ndarray lists with a matrix shape.
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public abstract class MatrixBaseNDArrayList<X extends BaseNDArrayList> extends AbstractList<X> {
|
|
||||||
private List<X> list = new ArrayList<>();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get a view of the underlying array
|
|
||||||
* relative to the size of the actual array.
|
|
||||||
* (Sometimes there are overflows in the internals
|
|
||||||
* but you want to use the internal INDArray for computing something
|
|
||||||
* directly, this gives you the relevant subset that reflects the content of the list)
|
|
||||||
* @return the view of the underlying ndarray relative to the collection's real size
|
|
||||||
*/
|
|
||||||
public INDArray array() {
|
|
||||||
List<INDArray> retList = new ArrayList<>(list.size());
|
|
||||||
for(X x : list) {
|
|
||||||
INDArray arr = x.array();
|
|
||||||
retList.add(arr.reshape(1, arr.length()));
|
|
||||||
}
|
|
||||||
|
|
||||||
return Nd4j.concat(0,retList.toArray(new INDArray[retList.size()]));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int size() {
|
|
||||||
return list.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isEmpty() {
|
|
||||||
return list.isEmpty();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean contains(Object o) {
|
|
||||||
return list.contains(o);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Iterator<X> iterator() {
|
|
||||||
return list.iterator();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object[] toArray() {
|
|
||||||
return list.toArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public <T> T[] toArray(T[] ts) {
|
|
||||||
return list.toArray(ts);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean add(X aX) {
|
|
||||||
return list.add(aX);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean remove(Object o) {
|
|
||||||
return list.remove(o);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean containsAll(Collection<?> collection) {
|
|
||||||
return list.containsAll(collection);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean addAll(Collection<? extends X> collection) {
|
|
||||||
return list.addAll(collection);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean addAll(int i, Collection<? extends X> collection) {
|
|
||||||
return list.addAll(collection);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean removeAll(Collection<?> collection) {
|
|
||||||
return list.removeAll(collection);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean retainAll(Collection<?> collection) {
|
|
||||||
return list.retainAll(collection);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void clear() {
|
|
||||||
list.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public X get(int i) {
|
|
||||||
return list.get(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public X set(int i, X aX) {
|
|
||||||
return list.set(i,aX);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void add(int i, X aX) {
|
|
||||||
list.add(i,aX);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public X remove(int i) {
|
|
||||||
return list.remove(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int indexOf(Object o) {
|
|
||||||
return list.indexOf(o);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int lastIndexOf(Object o) {
|
|
||||||
return list.lastIndexOf(o);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListIterator<X> listIterator() {
|
|
||||||
return list.listIterator();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListIterator<X> listIterator(int i) {
|
|
||||||
return list.listIterator(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return list.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get entry i,j in the matrix
|
|
||||||
* @param i the row
|
|
||||||
* @param j the column
|
|
||||||
* @return the entry at i,j if it exists
|
|
||||||
*/
|
|
||||||
public Number getEntry(int i,int j) {
|
|
||||||
return list.get(i).get(j);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,32 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.list.matrix;
|
|
||||||
|
|
||||||
import org.nd4j.list.NDArrayList;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A {@link MatrixBaseNDArrayList}
|
|
||||||
* for double data type
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class MatrixNDArrayList extends MatrixBaseNDArrayList<NDArrayList> {
|
|
||||||
public MatrixNDArrayList() {
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -18,15 +18,12 @@ package org.nd4j.list;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.list.matrix.MatrixNDArrayList;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
|
||||||
|
|
||||||
public class NDArrayListTest extends BaseNd4jTest {
|
public class NDArrayListTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
@ -71,24 +68,4 @@ public class NDArrayListTest extends BaseNd4jTest {
|
||||||
assertEquals(ndArrayList.size(),ndArrayList.array().length());
|
assertEquals(ndArrayList.size(),ndArrayList.array().length());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testMatrixList() {
|
|
||||||
MatrixNDArrayList matrixNDArrayList = new MatrixNDArrayList();
|
|
||||||
for(int i = 0; i < 5; i++) {
|
|
||||||
NDArrayList ndArrayList = new NDArrayList();
|
|
||||||
for(int j = 0; j < 4; j++) {
|
|
||||||
ndArrayList.add((double) j);
|
|
||||||
}
|
|
||||||
|
|
||||||
matrixNDArrayList.add(ndArrayList);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray arr = matrixNDArrayList.array();
|
|
||||||
assertEquals(5,arr.rows());
|
|
||||||
assertFalse(matrixNDArrayList.isEmpty());
|
|
||||||
assertEquals(0.0,matrixNDArrayList.getEntry(0,0));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue