cavis/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java

73 lines
2.3 KiB
Java

/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import org.nd4j.linalg.api.ndarray.INDArray;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
/**
 * Swing helpers for displaying generated samples while running GAN tests.
 *
 * <p>Each sample is drawn as a {@value #IMAGE_SIDE}x{@value #IMAGE_SIDE} grayscale image, magnified
 * {@value #DISPLAY_SCALE}x with nearest-neighbour scaling.
 *
 * <p>NOTE(review): all Swing calls here run on the caller's thread, not via
 * {@code SwingUtilities.invokeLater}. Swing components should normally be touched only from the
 * event-dispatch thread — confirm callers accept this.
 */
public class GANVisualizationUtils {

    /** Width and height, in pixels, of a single (unscaled) sample image. */
    private static final int IMAGE_SIDE = 28;

    /** Number of tensor values consumed per sample: one per pixel. */
    private static final int IMAGE_PIXELS = IMAGE_SIDE * IMAGE_SIDE;

    /** Integer magnification applied to each sample image when it is displayed. */
    private static final int DISPLAY_SCALE = 8;

    /** Largest gray level storable in a {@link BufferedImage#TYPE_BYTE_GRAY} raster. */
    private static final int MAX_GRAY = 255;

    /**
     * Creates an undecorated-content frame titled "Viz" that is disposed when closed.
     *
     * @return a new, not yet visible frame using a {@link BorderLayout}
     */
    public static JFrame initFrame() {
        JFrame frame = new JFrame();
        frame.setTitle("Viz");
        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
        frame.setLayout(new BorderLayout());
        return frame;
    }

    /**
     * Creates a grid panel for the samples, adds it to the centre of {@code frame}, and shows the frame.
     *
     * <p>The grid has {@code numSamples / 3} rows. {@link GridLayout} then derives the column count
     * from the number of components added. When {@code numSamples < 3} the row count is 0, which
     * {@link GridLayout} treats as "one column, as many rows as needed".
     *
     * @param frame      frame that receives the panel and is made visible
     * @param numSamples number of samples that will be displayed
     * @return the panel that {@link #visualize} fills
     */
    public static JPanel initPanel(JFrame frame, int numSamples) {
        JPanel panel = new JPanel();
        panel.setLayout(new GridLayout(numSamples / 3, 1, 8, 8));
        frame.add(panel, BorderLayout.CENTER);
        frame.setVisible(true);
        return panel;
    }

    /**
     * Replaces the panel's contents with one image per sample, then re-lays out and packs the frame.
     *
     * @param samples tensors to draw; each must hold at least {@value #IMAGE_PIXELS} values
     * @param frame   frame containing {@code panel}
     * @param panel   panel previously returned by {@link #initPanel}
     */
    public static void visualize(INDArray[] samples, JFrame frame, JPanel panel) {
        panel.removeAll();
        for (INDArray sample : samples) {
            panel.add(getImage(sample));
        }
        frame.revalidate();
        frame.pack();
    }

    /**
     * Renders the first {@value #IMAGE_PIXELS} values of {@code tensor} as a magnified grayscale label.
     *
     * <p>Values are read by linear index in row-major order. Element {@code i} becomes pixel
     * {@code (x = i % IMAGE_SIDE, y = i / IMAGE_SIDE)}.
     *
     * @param tensor sample values, assumed to lie in [-1, 1] (e.g. a tanh output); TODO confirm
     * @return a label showing the scaled image
     * @throws IllegalArgumentException if {@code tensor} has fewer than {@value #IMAGE_PIXELS} values
     */
    private static JLabel getImage(INDArray tensor) {
        long available = tensor.length();
        if (available < IMAGE_PIXELS) {
            throw new IllegalArgumentException(
                    "Expected at least " + IMAGE_PIXELS + " values per sample but got " + available);
        }
        int[] gray = new int[IMAGE_PIXELS];
        for (int i = 0; i < IMAGE_PIXELS; i++) {
            gray[i] = toGray(tensor.getDouble(i));
        }
        BufferedImage bi = new BufferedImage(IMAGE_SIDE, IMAGE_SIDE, BufferedImage.TYPE_BYTE_GRAY);
        // One sample per element, filled row by row: this matches the (i % side, i / side) layout.
        bi.getRaster().setPixels(0, 0, IMAGE_SIDE, IMAGE_SIDE, gray);
        int scaledSide = DISPLAY_SCALE * IMAGE_SIDE;
        Image scaled = bi.getScaledInstance(scaledSide, scaledSide, Image.SCALE_REPLICATE);
        return new JLabel(new ImageIcon(scaled));
    }

    /**
     * Maps a value from [-1, 1] linearly onto the gray range [0, {@value #MAX_GRAY}].
     *
     * <p>Out-of-range results are clamped. Without clamping, a byte-backed raster keeps only the low
     * 8 bits, so out-of-range values would wrap around and corrupt the image.
     */
    private static int toGray(double value) {
        int level = (int) ((value + 1.0) / 2.0 * MAX_GRAY);
        return Math.max(0, Math.min(MAX_GRAY, level));
    }
}