/*
 * Decompiled with CFR 0.152.
 */
package com.databricks.spark.sql.perf.mllib.classification;

import com.databricks.spark.sql.perf.MLMetric;
import com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm;
import com.databricks.spark.sql.perf.mllib.MLBenchContext;
import com.databricks.spark.sql.perf.mllib.OptionImplicits$;
import com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator;
import com.databricks.spark.sql.perf.mllib.TestFromTraining;
import com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer;
import com.databricks.spark.sql.perf.mllib.data.DataGenerator$;
import java.io.Serializable;
import java.util.Random;
import org.apache.spark.ml.ModelBuilderSSP$;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Predef$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

/**
 * Benchmark definition for Spark ML's {@link NaiveBayes} classifier.
 *
 * <p>This is the compiled (and decompiled) form of a Scala {@code object}: the single instance
 * is created by the static initializer and published through {@link #MODULE$}. Methods that
 * just call {@code Trait.method$(this, ...)} forward to the implementations provided by the
 * mixed-in traits ({@link BenchmarkAlgorithm}, {@link TestFromTraining},
 * {@link TrainingSetFromTransformer}, {@link ScoringWithEvaluator}). The benchmark-specific
 * logic lives in {@link #initialData}, {@link #trueModel}, {@link #getPipelineStage} and
 * {@link #evaluator}.
 */
public final class NaiveBayes$
implements BenchmarkAlgorithm,
TestFromTraining,
TrainingSetFromTransformer,
ScoringWithEvaluator {
    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static NaiveBayes$ MODULE$;

    static {
        new NaiveBayes$();
    }

    /** Scores {@code model} on {@code testSet}; forwards to the {@link ScoringWithEvaluator} implementation. */
    @Override
    public final MLMetric score(MLBenchContext ctx, Dataset<Row> testSet, Transformer model) {
        return ScoringWithEvaluator.score$(this, ctx, testSet, model);
    }

    /** Returns the training set; forwards to the {@link TrainingSetFromTransformer} implementation. */
    @Override
    public final Dataset<Row> trainingDataSet(MLBenchContext ctx) {
        return TrainingSetFromTransformer.trainingDataSet$(this, ctx);
    }

    /** Returns the test set; forwards to the {@link TestFromTraining} implementation. */
    @Override
    public final Dataset<Row> testDataSet(MLBenchContext ctx) {
        return TestFromTraining.testDataSet$(this, ctx);
    }

    /** Returns the benchmark name; forwards to the {@link BenchmarkAlgorithm} implementation. */
    @Override
    public String name() {
        return BenchmarkAlgorithm.name$(this);
    }

    /** Returns extra test methods; forwards to the {@link BenchmarkAlgorithm} implementation. */
    @Override
    public Map<String, Function0<?>> testAdditionalMethods(MLBenchContext ctx, Transformer transformer) {
        return BenchmarkAlgorithm.testAdditionalMethods$(this, ctx, transformer);
    }

    /**
     * Generates the raw feature data for the benchmark.
     *
     * <p>Each of the {@code numFeatures} features gets a random arity of
     * {@code 2 + rng.nextInt(maxFeatureArity - 2)}, i.e. uniformly in [2, 19]. These arities are
     * passed to {@code DataGenerator.generateMixedFeatures} together with {@code numExamples},
     * the context seed and {@code numPartitions}.
     * NOTE(review): presumably an arity is the number of categories of a categorical feature —
     * confirm against DataGenerator.
     *
     * @param ctx benchmark context supplying parameters, seed and random generator
     * @return the generated feature dataset
     */
    @Override
    public Dataset<Row> initialData(MLBenchContext ctx) {
        Random rng = ctx.newGenerator();
        // Exclusive upper bound: nextInt(18) yields 0..17, so arities range over 2..19.
        int maxFeatureArity = 20;
        int[] featureArity = (int[])((TraversableOnce)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), OptionImplicits$.MODULE$.oI2I(ctx.params().numFeatures())).map((Function1)(JFunction1.mcII.sp & Serializable & scala.Serializable)x$1 -> 2 + rng.nextInt(maxFeatureArity - 2), IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Int());
        return DataGenerator$.MODULE$.generateMixedFeatures(ctx.sqlContext(), OptionImplicits$.MODULE$.oL2L(ctx.params().numExamples()), ctx.seed(), OptionImplicits$.MODULE$.oI2I(ctx.params().numPartitions()), featureArity);
    }

    /**
     * Builds the reference ("true") Naive Bayes model for this benchmark.
     * NOTE(review): presumably {@link TrainingSetFromTransformer} applies it to
     * {@link #initialData} to produce labels — confirm.
     *
     * <p>{@code pi} holds log class priors. {@code numClasses} weights are drawn as
     * {@code rng.nextDouble() + 1e-5} and stored as {@code log(w) - log(sum(w))}, which is the log
     * of the normalised weight.
     *
     * <p>{@code theta} holds log class-conditional feature probabilities. Row {@code i} puts
     * probability 0.7 on feature {@code i} and splits the remaining 0.3 evenly over the other
     * {@code numFeatures - 1} features.
     *
     * @param ctx benchmark context supplying {@code numClasses}, {@code numFeatures} and the generator
     * @return the model built by {@code ModelBuilderSSP.newNaiveBayesModel(pi, theta)}
     * @throws IllegalArgumentException if {@code numClasses > numFeatures}, since every class needs
     *     a distinct feature on which to place its 0.7 probability mass
     */
    @Override
    public Transformer trueModel(MLBenchContext ctx) {
        Random rng = ctx.newGenerator();
        int numClassesParam = OptionImplicits$.MODULE$.oI2I(ctx.params().numClasses());
        int numFeaturesParam = OptionImplicits$.MODULE$.oI2I(ctx.params().numFeatures());
        // Row i of theta writes its peak at feature index i (see $anonfun$trueModel$3). With more
        // classes than features, that write would run past the end of the row and surface as an
        // ArrayIndexOutOfBoundsException deep inside the tabulate lambda. Fail fast and say why.
        if (numClassesParam > numFeaturesParam) {
            throw new IllegalArgumentException("NaiveBayes trueModel requires numClasses <= numFeatures, but got numClasses=" + numClassesParam + ", numFeatures=" + numFeaturesParam);
        }
        double[] unnormalizedProbs = (double[])((TraversableOnce)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), OptionImplicits$.MODULE$.oI2I(ctx.params().numClasses())).map((Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)x$2 -> rng.nextDouble() + 1.0E-5, IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Double());
        double logProbSum = package$.MODULE$.log(BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(unnormalizedProbs)).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)));
        // log(p) - log(sum) == log(p / sum): log of the normalised class prior.
        double[] piArray = (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(unnormalizedProbs)).map((Function1)(JFunction1.mcDD.sp & Serializable & scala.Serializable)prob -> package$.MODULE$.log(prob) - logProbSum, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        double currClassProb = 0.7;
        // One probability row per class (built by $anonfun$trueModel$3), then element-wise log.
        double[][] thetaArray = (double[][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])Array$.MODULE$.tabulate(OptionImplicits$.MODULE$.oI2I(ctx.params().numClasses()), (Function1 & Serializable & scala.Serializable)i -> NaiveBayes$.$anonfun$trueModel$3(currClassProb, ctx, BoxesRunTime.unboxToInt((Object)i)), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE))))).map((Function1 & Serializable & scala.Serializable)x$3 -> (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(x$3)).map((Function1)(JFunction1.mcDD.sp & Serializable & scala.Serializable)x -> package$.MODULE$.log(x), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE))));
        Vector pi = Vectors$.MODULE$.dense(piArray);
        // thetaArray is flattened row by row; isTransposed = true makes DenseMatrix read the
        // values in row-major order, giving a numClasses x numFeatures matrix.
        DenseMatrix theta = new DenseMatrix(OptionImplicits$.MODULE$.oI2I(ctx.params().numClasses()), OptionImplicits$.MODULE$.oI2I(ctx.params().numFeatures()), (double[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])thetaArray)).flatten((Function1 & Serializable & scala.Serializable)xs -> Predef$.MODULE$.wrapDoubleArray(xs), ClassTag$.MODULE$.Double()), true);
        return ModelBuilderSSP$.MODULE$.newNaiveBayesModel(pi, (Matrix)theta);
    }

    /**
     * Returns the estimator under test: a new {@link NaiveBayes} whose smoothing is taken from
     * the {@code smoothing} benchmark parameter.
     */
    @Override
    public PipelineStage getPipelineStage(MLBenchContext ctx) {
        return new NaiveBayes().setSmoothing(OptionImplicits$.MODULE$.oD2D(ctx.params().smoothing()));
    }

    /** Returns a {@link MulticlassClassificationEvaluator} with its default settings. */
    @Override
    public Evaluator evaluator(MLBenchContext ctx) {
        return new MulticlassClassificationEvaluator();
    }

    /**
     * Body of the {@code Array.tabulate} lambda in {@link #trueModel}. It builds row {@code i} of
     * theta before the log transform.
     *
     * @param currClassProb$1 probability placed on feature {@code i} (0.7 at the call site)
     * @param ctx$1 context supplying {@code numFeatures}
     * @param i class index; must be less than {@code numFeatures}
     * @return an array of length {@code numFeatures} holding {@code currClassProb$1} at index
     *     {@code i} and {@code (1 - currClassProb$1) / (numFeatures - 1)} everywhere else
     */
    public static final /* synthetic */ double[] $anonfun$trueModel$3(double currClassProb$1, MLBenchContext ctx$1, int i) {
        // With numFeatures == 1 this is a double division by zero (Infinity). That value is never
        // observed, because the single slot is overwritten by probs[i] below.
        double baseProbMass = (1.0 - currClassProb$1) / (double)(OptionImplicits$.MODULE$.oI2I(ctx$1.params().numFeatures()) - 1);
        double[] probs = (double[])Array$.MODULE$.fill(OptionImplicits$.MODULE$.oI2I(ctx$1.params().numFeatures()), (Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> baseProbMass, ClassTag$.MODULE$.Double());
        probs[i] = currClassProb$1;
        return probs;
    }

    /** Publishes this instance as {@link #MODULE$} and runs the mixed-in trait initializers. */
    private NaiveBayes$() {
        MODULE$ = this;
        BenchmarkAlgorithm.$init$(this);
        TestFromTraining.$init$(this);
        TrainingSetFromTransformer.$init$(this);
        ScoringWithEvaluator.$init$(this);
    }
}

