/*
 *
 *    ******************************************************************************
 *    *
 *    * This program and the accompanying materials are made available under the
 *    * terms of the Apache License, Version 2.0 which is available at
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 *    *
 *    *  See the NOTICE file distributed with this work for additional
 *    *  information regarding copyright ownership.
 *    * Unless required by applicable law or agreed to in writing, software
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *    * License for the specific language governing permissions and limitations
 *    * under the License.
 *    *
 *    * SPDX-License-Identifier: Apache-2.0
 *    *****************************************************************************
 *
 */
package net.brutex.spark;

import com.fasterxml.jackson.core.Version;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SparkSession;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.deeplearning4j.ui.api.UIServer;
import org.junit.jupiter.api.*;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;

import java.io.File;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/**
 * Integration tests for DataVec transformations and distributed DL4J training
 * on a Spark cluster.
 *
 * @author raver119@gmail.com
 */
@Slf4j
@Tag("integration")
public class BrianTest2 /*extends BaseDL4JTest*/ {
    static {
        String OS = System.getProperty("os.name").toLowerCase();

        if (OS.contains("win")) {
            System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
        } else {
            System.setProperty("hadoop.home.dir", "/");
        }
    }

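    //Timeout hook mirroring BaseDL4JTest#getTimeoutMilliseconds; only relevant if the
    //commented-out "extends BaseDL4JTest" above is restored.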
    public long getTimeoutMilliseconds() {
        return 400000L;
    }

    private JavaSparkContext sc;


    /*
    @BeforeAll
    public void loadData() {
        sc.addFile("https://www.openml.org/data/get_csv/1595261/phpMawTba");
        org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get(sc.hadoopConfiguration());
        try {
            String file = SparkFiles.get("phpMawTba");
            Path target = new Path("/user/brian/" + "mydata.csv");
            //Apache Commons
            FileUtils.copyFile(new File(file), hdfs.create(target));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    */

    @BeforeEach
    public void setUp() throws Exception {
        log.info("Running @BeforeEach scope");
        System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
        Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION;
        System.out.println("Jackson version found: " + version);
        System.out.println(System.getProperty("java.vm.name")+"\n"+System.getProperty("java.runtime.version"));

        SparkConf sparkConf = new SparkConf()
                .setMaster("spark://10.5.5.200:7077")
                .setAppName("Brian3")
                .set("spark.driver.bindAddress", "10.5.5.145")
                .set("spark.network.timeout", "240000")
                .set("spark.driver.host", "10.5.5.145")
                .set("spark.deploy.mode", "client")
                .set("spark.executor.memory", "4g")
                .set("spark.cores.max", "2")
                .set("spark.worker.cleanup.enabled", "false")
                .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
                .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
                .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");

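        //getOrCreate() reuses an already-running session in this JVM; otherwise it
        //creates a new one from the configuration above.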
        SparkSession spark = SparkSession.builder()
                .master("spark://10.5.5.200:7077")
                .appName("BrianTest2")
                .config(sparkConf)
                .getOrCreate();

        this.sc =  JavaSparkContext.fromSparkContext(spark.sparkContext());

        /*
        Every jar on the driver's classpath is shipped to the Spark executors.
         */
        final String clpath = System.getProperty("java.class.path");
        log.info("java.class.path=\r\n{}\r\n", clpath);
        final String separator = System.getProperty("path.separator");
        final String[] a = clpath.split(separator);
        for(String s : a) {
            File f = new File(s);
            if(f.exists() && f.isFile() && s.endsWith(".jar")) {
                log.info("Adding jar to SparkContext '{}'.", f.getName());
                this.sc.addJar(s);
            }
        }
    }

    @AfterEach
    public void tearDown() throws Exception {
        if(sc!=null) this.sc.stop();
        UIServer.stopInstance();
    }

    @Test
    public void testStringsTokenization1() throws Exception {

        final JavaRDD<String> rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
        //shrink for Test
        //List<String> list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"});
        //JavaRDD<String> rdd = sc.parallelize(list);

        //rdd = rdd.sample(true, 1.0, 1);
        log.info("Record count: " + rdd.count());
        log.info("Sample: " + rdd.top(3));

        Assertions.assertEquals(146889, rdd.count());
    }

    @Test
    public void testSchemaCreation() throws Exception {
        log.info(System.getProperty("java.vm.name")+"\n"+System.getProperty("java.runtime.version"));
        final JavaRDD<String> rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
        rdd.cache();

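        //The three distinct-value RDDs below are only consumed by the commented-out
        //TransformProcess further down; the active process uses a fixed category list.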
        JavaRDD<String> cities = rdd.map( (Function<String, String>) line -> {
                    return line.split(",")[1];
        }).cache();

        JavaRDD<String> stateCodeList = rdd.map( (Function<String, String>) line -> {
            return line.split(",")[2];
        }).cache();

        JavaRDD<String> countryCodeList = rdd.map( (Function<String, String>) line -> {
            return line.split(",")[3];
        }).cache();


        //Parse each CSV line into a list of Writables
        CSVRecordReader recordReader = new CSVRecordReader(0, ',');
        JavaRDD<List<Writable>> convertedRDD = rdd.map(new StringToWritablesFunction(recordReader));

        //Source Schema
        Schema inputSchema = new Schema.Builder()
                .addColumnLong("city_id")
                .addColumnsString("city_name", "state_code", "country_code")
                .addColumnsString("country_full")
                .addColumnsDouble("lat", "lon")
                .build();
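        //The schema assumes rows shaped like the following (hypothetical sample, not
        //taken from the actual dataset):
        //  1,"Athens","AT","GR","Greece",37.9838,23.7275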

        //Running Transformation
        /*
        TransformProcess tp  = new TransformProcess.Builder(inputSchema)
                .removeColumns("country_full", "lat", "lon")
                .addConstantIntegerColumn("dummy_spalte", 1)
                .stringToCategorical("state_code", stateCodeList.distinct().collect())
                .stringToCategorical("country_code", countryCodeList.distinct().collect())
                .stringToCategorical("city_name", cities.distinct().collect())
                .filter(new FilterInvalidValues())
                .categoricalToOneHot("city_name")
                .categoricalToOneHot("state_code")
                .categoricalToOneHot("country_code")
                .build();
        */
        TransformProcess tp = new TransformProcess.Builder(inputSchema)
                .removeAllColumnsExceptFor("country_code", "lat", "lon")
                .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH"))
                .filter(new FilterInvalidValues())
                .categoricalToOneHot("country_code")
                .build();
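        //Resulting schema, in order: country_code[GR], country_code[FR], country_code[DE],
        //country_code[CH], lat, lon - i.e. 6 columns after one-hot encoding.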

        //log.info("Final Schema: " +tp.getFinalSchema().toString());
        //Execute Transformation Process
        //convertedRDD.repartition(1);
        //convertedRDD.cache();
        JavaRDD<List<Writable>> processedData = SparkTransformExecutor.execute(convertedRDD, tp);
        //processedData.repartition(1);
        //processedData.cache();
        //log.info("Datenmenge nach processing: " + processedData.count());


        //Vectorize
        int labelIndex = 0; //column that holds the label
        int numLabels = 4;  //number of classes (GR, FR, DE, CH)

        DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, false);
        JavaRDD<DataSet> rddDataSet = processedData.map(datavecFunction);
        log.info("rddDataset: " + rddDataSet.toDebugString());
        //Shuffle; sortBy returns a new RDD, so the result must be reassigned
        Random rand = new Random();
        rddDataSet = rddDataSet.sortBy((Function<DataSet, Double>) s -> rand.nextDouble(), true, 8);

        //log.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());

        //Skipped: saving would write each record one by one to HDFS
        /*
        int miniBatchSize = 1; //Minibatch size of the saved DataSet objects
        final String exportPath = "hdfs://10.5.5.200:9000/user/brian/data";
        JavaRDD<String> paths = rddDataSet.mapPartitionsWithIndex(
                new BatchAndExportDataSetsFunction(miniBatchSize, exportPath),
                true)
                ;
        paths.collect();
        */


        // Configure distributed training required for gradient sharing implementation
        VoidConfiguration conf = VoidConfiguration.builder()
                .unicastPort(40123)             //Port that workers will use to communicate. Use any free port
                //.networkMask("10.0.0.0/16")     //Network mask for communication. Examples 10.0.0.0/24, or 192.168.0.0/16 etc
                .controllerAddress("10.5.5.145")
                .build();
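        //controllerAddress matches spark.driver.host configured above; workers connect
        //back to the driver here.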

        //Create the TrainingMaster instance
        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1000)
                .batchSizePerWorker(20000)  //Batch size for training
                .updatesThreshold(1e-3)     //Update threshold for quantization/compression. See technical explanation page
                .workersPerNode(1)          //Equal to number of GPUs. For CPUs: use 1; use > 1 for large core count
                .exportDirectory("/user/brian/")
                .build();

        //Alternative TrainingMaster implementations, kept for reference:
/*
        TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
                .rddTrainingApproach(RDDTrainingApproach.Direct) //when "export", tries to save everything first
                .collectTrainingStats(false).build();
  */
        /*
        TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, minibatch)
                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(this.gradientThreshold))
                .residualPostProcessor(new ResidualClippingPostProcessor(5, 5))
                .build();
*/
        //Define Network

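        //Note: XENT loss with sigmoid activation treats the 4 outputs as independent
        //binary targets; for mutually exclusive classes, MCXENT with softmax would be
        //the more common choice.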
        MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Nesterovs(0.1, 0.9))
                .list()
                    .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build())
                    .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
                    //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
                    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
                .build();

        //Define SparkNet
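        //SparkDl4jMultiLayer pairs the network configuration with the TrainingMaster,
        //which decides how fitting is distributed across the cluster.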
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, trainingMaster);


        JavaRDD<DataSet>[] split = rddDataSet.randomSplit(new double[] {0.9, 0.1}, 123);
        //JavaRDD<DataSet> trainingData = split[0];
        JavaRDD<DataSet> trainingData = rddDataSet;
        JavaRDD<DataSet> testData = split[1];

        //Run training; each fit() call is one pass (epoch) over the training data
        for (int i = 0; i < 4; i++) {
            sparkNet.fit(trainingData);
        }

        //Evaluate
        MultiLayerNetwork finalNet = sparkNet.getNetwork();

        //Save (model persistence currently disabled)
        Configuration hconf = sc.hadoopConfiguration();
        hconf.set("hadoop.tmp.dir", "/user/brian/tmp");
        FileSystem fs = FileSystem.get(hconf);
        Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model");
        //fs.mkdirs(p);
        //ModelSerializer.writeModel(finalNet, fs.create(p), true );

        Evaluation eval = new Evaluation(4); //4 output classes
        Iterator<DataSet> iter = testData.toLocalIterator();
        log.info("testData has " + testData.count() + " DataSets");
        while(iter.hasNext()){
            DataSet next = iter.next();
            //log.info("getFeatures " + next.getFeatures() );
            INDArray output = finalNet.output(next.getFeatures()); //get the network's prediction
            //log.info("output "+ output.toStringFull());
            eval.eval(next.getLabels(), output); //check the prediction against the true class
            //log.info("Predict " + finalNet.predict(next));
        }
        log.info("Evaluation stats: " + eval.stats());

    }

}