/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.data;

import gov.sandia.cognition.collection.DefaultMultiCollection;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.data.WeightedTargetEstimatePair;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorUtil;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import gov.sandia.cognition.util.DefaultPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Static helper methods for building, reshaping, and summarizing learning
 * datasets: appending bias terms, decoupling vector datasets into one dataset
 * per dimension, splitting labeled data, computing output statistics, and
 * checking input dimensionalities.
 */
public class DatasetUtil {

    /**
     * Appends a bias value of {@code 1.0} to every vector in the dataset.
     * Equivalent to {@code appendBias(dataset, 1.0)}.
     *
     * @param dataset the vectors to extend
     * @return a new list holding, in iteration order, each vector with the
     *     bias appended
     */
    public static ArrayList<Vector> appendBias(Collection<? extends Vector> dataset) {
        return DatasetUtil.appendBias(dataset, 1.0);
    }

    /**
     * Appends a constant bias value to every vector in the dataset by calling
     * {@code vector.stack(bias)}, where {@code bias} is created from
     * {@code biasValue} by the default {@link VectorFactory}.
     *
     * @param dataset the vectors to extend
     * @param biasValue the value to append
     * @return a new list holding, in iteration order, the stacked vectors
     */
    public static ArrayList<Vector> appendBias(Collection<? extends Vector> dataset, double biasValue) {
        ArrayList<Vector> biasDataset = new ArrayList<Vector>(dataset.size());
        // Built once, outside the loop, and passed to every stack() call.
        Vector bias = VectorFactory.getDefault().copyValues(biasValue);
        for (Vector vector : dataset) {
            biasDataset.add(vector.stack(bias));
        }
        return biasDataset;
    }

    /**
     * Splits a dataset of vector-to-vector pairs into one scalar-to-scalar
     * dataset per dimension. Element {@code i} of the result holds, for every
     * pair in iteration order, a pair of the {@code i}-th input element and the
     * {@code i}-th output element. When a source pair is a
     * {@link WeightedInputOutputPair}, each derived scalar pair is a
     * {@link DefaultWeightedInputOutputPair} carrying the same weight;
     * otherwise it is a plain {@link DefaultInputOutputPair}.
     *
     * @param dataset the pairs to decouple; every input and output vector must
     *     have the dimensionality of the first pair's input
     * @return one list per dimension, or an empty list if {@code dataset} is
     *     empty
     * @throws IllegalArgumentException if any input or output vector has a
     *     dimensionality different from the first pair's input
     */
    public static ArrayList<ArrayList<InputOutputPair<Double, Double>>> decoupleVectorPairDataset(Collection<? extends InputOutputPair<? extends Vector, ? extends Vector>> dataset) {
        if (dataset.isEmpty()) {
            // No first element to take the dimensionality from, and nothing to split.
            return new ArrayList<ArrayList<InputOutputPair<Double, Double>>>();
        }
        int numSamples = dataset.size();
        int M = dataset.iterator().next().getInput().getDimensionality();
        ArrayList<ArrayList<InputOutputPair<Double, Double>>> retval = new ArrayList<ArrayList<InputOutputPair<Double, Double>>>(M);
        for (int i = 0; i < M; ++i) {
            retval.add(new ArrayList<InputOutputPair<Double, Double>>(numSamples));
        }
        for (InputOutputPair<? extends Vector, ? extends Vector> inputOutputPair : dataset) {
            Vector input = inputOutputPair.getInput();
            Vector output = inputOutputPair.getOutput();
            if (input.getDimensionality() != M || output.getDimensionality() != M) {
                throw new IllegalArgumentException("All input-output Vectors must have same dimension!");
            }
            // Weight is a per-pair property, so resolve it once rather than per dimension.
            boolean weighted = inputOutputPair instanceof WeightedInputOutputPair;
            double weight = DatasetUtil.getWeight(inputOutputPair);
            for (int i = 0; i < M; ++i) {
                double x = input.getElement(i);
                double y = output.getElement(i);
                DefaultInputOutputPair<Double, Double> rowPair;
                if (weighted) {
                    rowPair = new DefaultWeightedInputOutputPair<Double, Double>(x, y, weight);
                } else {
                    rowPair = new DefaultInputOutputPair<Double, Double>(x, y);
                }
                retval.get(i).add(rowPair);
            }
        }
        return retval;
    }

    /**
     * Splits a dataset of vectors into one list of scalars per dimension.
     * Element {@code i} of the result holds the {@code i}-th element of every
     * vector, in iteration order.
     *
     * @param dataset the vectors to decouple; all must have the dimensionality
     *     of the first vector
     * @return one list per dimension, or an empty list if {@code dataset} is
     *     empty
     * @throws IllegalArgumentException if any vector has a dimensionality
     *     different from the first vector's
     */
    public static ArrayList<ArrayList<Double>> decoupleVectorDataset(Collection<? extends Vector> dataset) {
        if (dataset.isEmpty()) {
            // No first element to take the dimensionality from, and nothing to split.
            return new ArrayList<ArrayList<Double>>();
        }
        int M = dataset.iterator().next().getDimensionality();
        int num = dataset.size();
        ArrayList<ArrayList<Double>> decoupledDatasets = new ArrayList<ArrayList<Double>>(M);
        for (int i = 0; i < M; ++i) {
            decoupledDatasets.add(new ArrayList<Double>(num));
        }
        for (Vector vector : dataset) {
            if (M != vector.getDimensionality()) {
                throw new IllegalArgumentException("All vectors in the dataset must be the same size");
            }
            for (int i = 0; i < M; ++i) {
                decoupledDatasets.get(i).add(vector.getElement(i));
            }
        }
        return decoupledDatasets;
    }

    /**
     * Partitions the inputs of a boolean-labeled dataset by their label.
     *
     * @param <DataType> the input type
     * @param data the labeled examples; outputs must be non-null
     * @return a pair whose first element holds the inputs labeled {@code true}
     *     and whose second holds the inputs labeled {@code false}, each in
     *     iteration order
     */
    public static <DataType> DefaultPair<LinkedList<DataType>, LinkedList<DataType>> splitDatasets(Collection<? extends InputOutputPair<? extends DataType, Boolean>> data) {
        LinkedList<DataType> dtrue = new LinkedList<DataType>();
        LinkedList<DataType> dfalse = new LinkedList<DataType>();
        for (InputOutputPair<? extends DataType, Boolean> pair : data) {
            if (pair.getOutput().booleanValue()) {
                dtrue.add(pair.getInput());
            } else {
                dfalse.add(pair.getInput());
            }
        }
        return DefaultPair.create(dtrue, dfalse);
    }

    /**
     * Groups the inputs of a dataset by their output value.
     *
     * @param <InputType> the input type
     * @param <CategoryType> the output (category) type
     * @param data the examples to group
     * @return an insertion-ordered map from each output value, in order of
     *     first appearance, to the list of inputs with that output, in
     *     iteration order
     */
    public static <InputType, CategoryType> Map<CategoryType, List<InputType>> splitOnOutput(Iterable<? extends InputOutputPair<? extends InputType, ? extends CategoryType>> data) {
        LinkedHashMap<CategoryType, List<InputType>> result = new LinkedHashMap<CategoryType, List<InputType>>();
        for (InputOutputPair<? extends InputType, ? extends CategoryType> example : data) {
            CategoryType category = example.getOutput();
            List<InputType> examplesForCategory = result.get(category);
            if (examplesForCategory == null) {
                examplesForCategory = new ArrayList<InputType>();
                result.put(category, examplesForCategory);
            }
            examplesForCategory.add(example.getInput());
        }
        return result;
    }

    /**
     * Computes the M-by-M matrix whose {@code (i, j)} entry is the sum over all
     * vectors {@code d} of {@code d[i] * d[j]}; with the vectors as the columns
     * of X, this is {@code X * X^T}. Zero entries are left unset in the
     * returned sparse matrix.
     *
     * @param data the vectors; M is taken from the first one, and the list must
     *     be non-empty
     * @return the sparse symmetric matrix of summed outer products
     */
    public static Matrix computeOuterProductDataMatrix(ArrayList<? extends Vector> data) {
        int M = data.iterator().next().getDimensionality();
        int numSamples = data.size();
        SparseMatrix XXt = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(M, M);
        // The result is symmetric, so compute the upper triangle and mirror it.
        for (int j = 0; j < M; ++j) {
            for (int i = 0; i <= j; ++i) {
                double sum = 0.0;
                for (int k = 0; k < numSamples; ++k) {
                    Vector dk = data.get(k);
                    sum += dk.getElement(i) * dk.getElement(j);
                }
                if (sum == 0.0) {
                    continue;
                }
                XXt.setElement(i, j, sum);
                if (i != j) {
                    XXt.setElement(j, i, sum);
                }
            }
        }
        return XXt;
    }

    /**
     * Computes the unweighted arithmetic mean of the numeric outputs.
     *
     * @param data the examples, may be {@code null}
     * @return the mean output, or {@code 0.0} if {@code data} is {@code null}
     *     or empty
     */
    public static double computeOutputMean(Collection<? extends InputOutputPair<?, ? extends Number>> data) {
        if (data == null) {
            return 0.0;
        }
        double sum = 0.0;
        int count = 0;
        for (InputOutputPair<?, ? extends Number> example : data) {
            sum += example.getOutput().doubleValue();
            ++count;
        }
        if (count <= 0) {
            return 0.0;
        }
        return sum / (double) count;
    }

    /**
     * Computes the weighted mean of the numeric outputs, where each example's
     * weight comes from {@link #getWeight(InputOutputPair)} ({@code 1.0} for
     * pairs that are not {@link WeightedInputOutputPair}s).
     *
     * @param data the examples, may be {@code null}
     * @return the weighted mean output, or {@code 0.0} if {@code data} is
     *     {@code null}, empty, or its weights sum to zero
     */
    public static double computeWeightedOutputMean(Collection<? extends InputOutputPair<?, ? extends Number>> data) {
        if (data == null || data.isEmpty()) {
            return 0.0;
        }
        double sum = 0.0;
        double weightSum = 0.0;
        for (InputOutputPair<?, ? extends Number> example : data) {
            double weight = DatasetUtil.getWeight(example);
            sum += weight * example.getOutput().doubleValue();
            weightSum += weight;
        }
        if (weightSum == 0.0) {
            return 0.0;
        }
        return sum / weightSum;
    }

    /**
     * Computes the unweighted population variance of the numeric outputs
     * (the sum of squared deviations from the mean divided by the number of
     * examples, not by {@code n - 1}).
     *
     * @param data the examples, may be {@code null}
     * @return the output variance, or {@code 0.0} if {@code data} is
     *     {@code null} or empty
     */
    public static double computeOutputVariance(Collection<? extends InputOutputPair<?, ? extends Number>> data) {
        if (data == null) {
            return 0.0;
        }
        int count = data.size();
        if (count <= 0) {
            return 0.0;
        }
        double mean = DatasetUtil.computeOutputMean(data);
        double sum = 0.0;
        for (InputOutputPair<?, ? extends Number> example : data) {
            double difference = example.getOutput().doubleValue() - mean;
            sum += difference * difference;
        }
        return sum / (double) count;
    }

    /**
     * Collects the distinct output values of a dataset.
     *
     * @param <OutputType> the output type
     * @param data the examples, may be {@code null}
     * @return an insertion-ordered set of the distinct outputs, empty if
     *     {@code data} is {@code null}
     */
    public static <OutputType> Set<OutputType> findUniqueOutputs(Iterable<? extends InputOutputPair<?, ? extends OutputType>> data) {
        LinkedHashSet<OutputType> outputs = new LinkedHashSet<OutputType>();
        if (data != null) {
            for (InputOutputPair<?, ? extends OutputType> example : data) {
                outputs.add(example.getOutput());
            }
        }
        return outputs;
    }

    /**
     * Builds a histogram of the output values of a dataset by adding each
     * example's output to a {@link MapBasedDataHistogram}.
     *
     * @param <OutputType> the output type
     * @param data the examples, may be {@code null}
     * @return the histogram, empty if {@code data} is {@code null}
     */
    public static <OutputType> DataHistogram<OutputType> countOutputValues(Iterable<? extends InputOutputPair<?, ? extends OutputType>> data) {
        MapBasedDataHistogram<OutputType> outputs = new MapBasedDataHistogram<OutputType>();
        if (data != null) {
            for (InputOutputPair<?, ? extends OutputType> example : data) {
                outputs.add(example.getOutput());
            }
        }
        return outputs;
    }

    /**
     * Extracts the inputs of a dataset into a list.
     *
     * @param <InputType> the input type
     * @param data the examples, may be {@code null}
     * @return a new list of the inputs in iteration order, empty if
     *     {@code data} is {@code null}
     */
    public static <InputType> List<InputType> inputsList(Iterable<? extends InputOutputPair<? extends InputType, ?>> data) {
        ArrayList<InputType> inputs = new ArrayList<InputType>();
        if (data != null) {
            for (InputOutputPair<? extends InputType, ?> example : data) {
                inputs.add(example.getInput());
            }
        }
        return inputs;
    }

    /**
     * Views a collection as a {@link MultiCollection}.
     *
     * @param <EntryType> the element type
     * @param collection the collection to view
     * @return {@code collection} itself if it already is a
     *     {@link MultiCollection}; otherwise a new {@link DefaultMultiCollection}
     *     wrapping it as its only sub-collection
     */
    public static <EntryType> MultiCollection<EntryType> asMultiCollection(Collection<EntryType> collection) {
        if (collection instanceof MultiCollection) {
            return (MultiCollection<EntryType>) collection;
        }
        return new DefaultMultiCollection<EntryType>(Collections.singletonList(collection));
    }

    /**
     * Converts each element of a collection to a {@link Vector} via
     * {@link Vectorizable#convertToVector()}.
     *
     * @param collection the elements to convert
     * @return a new list of the converted vectors, in iteration order
     */
    public static Collection<Vector> asVectorCollection(Collection<? extends Vectorizable> collection) {
        ArrayList<Vector> result = new ArrayList<Vector>(collection.size());
        for (Vectorizable vectorizable : collection) {
            result.add(vectorizable.convertToVector());
        }
        return result;
    }

    /**
     * Finds the input dimensionality of a dataset by inspecting the first
     * example that is non-null, has a non-null input, and whose input converts
     * to a non-null vector.
     *
     * @param data the examples, may be {@code null}
     * @return the dimensionality of that first usable input, or {@code -1} if
     *     there is none
     */
    public static int getInputDimensionality(Iterable<? extends InputOutputPair<? extends Vectorizable, ?>> data) {
        if (data != null) {
            for (InputOutputPair<? extends Vectorizable, ?> example : data) {
                if (example == null) {
                    continue;
                }
                Vectorizable input = example.getInput();
                if (input == null) {
                    continue;
                }
                Vector vector = input.convertToVector();
                if (vector != null) {
                    return vector.getDimensionality();
                }
            }
        }
        return -1;
    }

    /**
     * Checks that every usable input in the dataset has the same
     * dimensionality as the first usable input, as found by
     * {@link #getInputDimensionality(Iterable)}.
     *
     * @param data the examples, may be {@code null}
     */
    public static void assertInputDimensionalitiesAllEqual(Iterable<? extends InputOutputPair<? extends Vectorizable, ?>> data) {
        DatasetUtil.assertInputDimensionalitiesAllEqual(data, DatasetUtil.getInputDimensionality(data));
    }

    /**
     * Checks that every usable input in the dataset has the given
     * dimensionality by calling {@link Vector#assertDimensionalityEquals(int)}
     * on each. Null examples, null inputs, and inputs that convert to a
     * {@code null} vector are skipped.
     * NOTE(review): the exception type on mismatch is whatever
     * {@code assertDimensionalityEquals} throws — confirm against Vector.
     *
     * @param data the examples, may be {@code null}
     * @param dimensionality the required dimensionality
     */
    public static void assertInputDimensionalitiesAllEqual(Iterable<? extends InputOutputPair<? extends Vectorizable, ?>> data, int dimensionality) {
        if (data != null) {
            for (InputOutputPair<? extends Vectorizable, ?> example : data) {
                if (example == null) {
                    continue;
                }
                Vectorizable input = example.getInput();
                if (input == null) {
                    continue;
                }
                Vector vector = input.convertToVector();
                if (vector != null) {
                    vector.assertDimensionalityEquals(dimensionality);
                }
            }
        }
    }

    /**
     * Finds the dimensionality of the first element that is non-null and
     * converts to a non-null vector.
     *
     * @param data the elements, may be {@code null}
     * @return the dimensionality of that first usable element, or {@code -1}
     *     if there is none
     */
    public static int getDimensionality(Iterable<? extends Vectorizable> data) {
        if (data != null) {
            for (Vectorizable vectorizable : data) {
                if (vectorizable == null) {
                    continue;
                }
                Vector vector = vectorizable.convertToVector();
                if (vector != null) {
                    return vector.getDimensionality();
                }
            }
        }
        return -1;
    }

    /**
     * Checks that all elements share the dimensionality of the first usable
     * element, as found by {@link #getDimensionality(Iterable)}. Delegates the
     * actual check to {@link VectorUtil}.
     *
     * @param data the elements, may be {@code null}
     */
    public static void assertDimensionalitiesAllEqual(Iterable<? extends Vectorizable> data) {
        VectorUtil.assertDimensionalitiesAllEqual(data, DatasetUtil.getDimensionality(data));
    }

    /**
     * Gets the weight of an input-output pair.
     *
     * @param pair the pair
     * @return the pair's weight if it is a {@link WeightedInputOutputPair},
     *     otherwise {@code 1.0}
     */
    public static double getWeight(InputOutputPair<?, ?> pair) {
        if (pair instanceof WeightedInputOutputPair) {
            return ((WeightedInputOutputPair<?, ?>) pair).getWeight();
        }
        return 1.0;
    }

    /**
     * Gets the weight of a target-estimate pair.
     *
     * @param pair the pair
     * @return the pair's weight if it is a {@link WeightedTargetEstimatePair},
     *     otherwise {@code 1.0}
     */
    public static double getWeight(TargetEstimatePair<?, ?> pair) {
        if (pair instanceof WeightedTargetEstimatePair) {
            return ((WeightedTargetEstimatePair<?, ?>) pair).getWeight();
        }
        return 1.0;
    }
}

