Update parameterized python tests, move python tests to proper package

master
agibsonccc 2021-03-18 12:18:39 +09:00
parent 13cae7fb60
commit 224f18a586
6 changed files with 157 additions and 53 deletions

View File

@ -1,31 +1,31 @@
/* /*
* ****************************************************************************** *
* * * * ******************************************************************************
* * * * *
* * This program and the accompanying materials are made available under the * * *
* * terms of the Apache License, Version 2.0 which is available at * * * This program and the accompanying materials are made available under the
* * https://www.apache.org/licenses/LICENSE-2.0. * * * terms of the Apache License, Version 2.0 which is available at
* * * * * https://www.apache.org/licenses/LICENSE-2.0.
* * See the NOTICE file distributed with this work for additional * * *
* * information regarding copyright ownership. * * * See the NOTICE file distributed with this work for additional
* * Unless required by applicable law or agreed to in writing, software * * * information regarding copyright ownership.
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * * * Unless required by applicable law or agreed to in writing, software
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * License for the specific language governing permissions and limitations * * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * under the License. * * * License for the specific language governing permissions and limitations
* * * * * under the License.
* * SPDX-License-Identifier: Apache-2.0 * * *
* ***************************************************************************** * * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/ */
package org.nd4j.python4j;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.python4j.*;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -35,7 +35,6 @@ import javax.annotation.concurrent.NotThreadSafe;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -43,6 +42,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@NotThreadSafe @NotThreadSafe
public class PythonNumpyBasicTest { public class PythonNumpyBasicTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
DataType[] types = new DataType[] { DataType[] types = new DataType[] {
DataType.BOOL, DataType.BOOL,
@ -61,9 +61,9 @@ public class PythonNumpyBasicTest {
}; };
long[][] shapes = new long[][]{ long[][] shapes = new long[][]{
new long[]{2, 3}, new long[]{2, 3},
new long[]{3}, new long[]{3},
new long[]{1}, new long[]{1},
new long[]{} // scalar new long[]{} // scalar
}; };
@ -78,23 +78,23 @@ public class PythonNumpyBasicTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
public void testConversion(DataType dataType,long[] shape){ public void testConversion(DataType dataType,long[] shape) {
try(PythonGIL pythonGIL = PythonGIL.lock()) { try(PythonGIL pythonGIL = PythonGIL.lock()) {
INDArray arr = Nd4j.zeros(dataType, shape); INDArray arr = Nd4j.zeros(dataType, shape);
PythonObject npArr = PythonTypes.convert(arr); PythonObject npArr = PythonTypes.convert(arr);
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr); INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
if (dataType == DataType.BFLOAT16){ if (dataType == DataType.BFLOAT16){
arr = arr.castTo(DataType.FLOAT); arr = arr.castTo(DataType.FLOAT);
} }
assertEquals(arr,arr2); assertEquals(arr,arr2);
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
public void testExecution(DataType dataType,long[] shape) { public void testExecution(DataType dataType,long[] shape) {
try(PythonGIL pythonGIL = PythonGIL.lock()) { try(PythonGIL pythonGIL = PythonGIL.lock()) {
List<PythonVariable> inputs = new ArrayList<>(); List<PythonVariable> inputs = new ArrayList<>();
@ -124,7 +124,7 @@ public class PythonNumpyBasicTest {
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
public void testInplaceExecution(DataType dataType,long[] shape) { public void testInplaceExecution(DataType dataType,long[] shape) {
try(PythonGIL pythonGIL = PythonGIL.lock()) { try(PythonGIL pythonGIL = PythonGIL.lock()) {
if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return; if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return;

View File

@ -1,4 +1,27 @@
/* /*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ****************************************************************************** * ******************************************************************************
* * * *
* * * *
@ -61,8 +84,7 @@ public class PythonNumpyCollectionsTest {
}).stream().map(Arguments::of); }).stream().map(Arguments::of);
} }
@Test @MethodSource("org.nd4j.python4j.PythonNumpyCollectionsTest#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testPythonDictFromMap(DataType dataType) throws PythonException { public void testPythonDictFromMap(DataType dataType) throws PythonException {
try(PythonGIL pythonGIL = PythonGIL.lock()) { try(PythonGIL pythonGIL = PythonGIL.lock()) {
@ -84,8 +106,7 @@ public class PythonNumpyCollectionsTest {
} }
@Test @MethodSource("org.nd4j.python4j.PythonNumpyCollectionsTest#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testPythonListFromList(DataType dataType) throws PythonException { public void testPythonListFromList(DataType dataType) throws PythonException {
try(PythonGIL pythonGIL = PythonGIL.lock()) { try(PythonGIL pythonGIL = PythonGIL.lock()) {

View File

@ -1,4 +1,27 @@
/* /*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ****************************************************************************** * ******************************************************************************
* * * *
* * * *
@ -18,11 +41,6 @@
* ***************************************************************************** * *****************************************************************************
*/ */
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonGIL;
import org.nd4j.python4j.PythonObject;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;

View File

@ -1,4 +1,27 @@
/* /*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ****************************************************************************** * ******************************************************************************
* * * *
* * * *
@ -18,8 +41,6 @@
* ***************************************************************************** * *****************************************************************************
*/ */
import org.nd4j.python4j.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -1,4 +1,27 @@
/* /*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ****************************************************************************** * ******************************************************************************
* * * *
* * * *
@ -61,8 +84,7 @@ public class PythonNumpyMultiThreadTest {
} }
@Test @MethodSource("org.nd4j.python4j.PythonNumpyMultiThreadTest#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testMultiThreading1(DataType dataType) throws Throwable { public void testMultiThreading1(DataType dataType) throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>()); final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
@ -100,8 +122,7 @@ public class PythonNumpyMultiThreadTest {
} }
@Test @MethodSource("org.nd4j.python4j.PythonNumpyMultiThreadTest#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testMultiThreading2(DataType dataType) throws Throwable { public void testMultiThreading2(DataType dataType) throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<>()); final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<>());

View File

@ -1,4 +1,27 @@
/* /*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ****************************************************************************** * ******************************************************************************
* * * *
* * * *