/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.AdditionalMeasureProducer;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class AdditiveRegression
extends Classifier
implements OptionHandler,
AdditionalMeasureProducer,
WeightedInstancesHandler {
    protected Classifier m_Classifier = new DecisionStump();
    private int m_classIndex;
    protected double m_shrinkage = 1.0;
    private FastVector m_additiveModels = new FastVector();
    private boolean m_debug = false;
    protected int m_maxModels = 10;

    public String globalInfo() {
        return " Meta classifier that enhances the performance of a regression base classifier. Each iteration fits a model to the residuals left by the classifier on the previous iteration. Prediction is accomplished by adding the predictions of each classifier. Reducing the shrinkage (learning rate) parameter helps prevent overfitting and has a smoothing effect but increases the learning time.  For more information see: Friedman, J.H. (1999). Stochastic Gradient Boosting. Technical Report Stanford University. http://www-stat.stanford.edu/~jhf/ftp/stobst.ps.";
    }

    public AdditiveRegression() {
        this(new DecisionStump());
    }

    public AdditiveRegression(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>(4);
        vector.addElement(new Option("\tFull class name of classifier to use, followed\n\tby scheme options. (required)\n\teg: \"weka.classifiers.bayes.NaiveBayes -D\"", "W", 1, "-W <classifier specification>"));
        vector.addElement(new Option("\tSpecify shrinkage rate. (default=1.0, ie. no shrinkage)\n", "S", 1, "-S"));
        vector.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        vector.addElement(new Option("\tSpecify max models to generate. (default = 10, ie. no max; keep going until error reduction threshold is reached)\n", "M", 1, "-M"));
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        this.setDebug(Utils.getFlag('D', stringArray));
        String string = Utils.getOption('W', stringArray);
        if (string.length() == 0) {
            throw new Exception("A classifier must be specified with the -w option.");
        }
        String[] stringArray2 = Utils.splitOptions(string);
        if (stringArray2.length == 0) {
            throw new Exception("Invalid classifier specification string");
        }
        String string2 = stringArray2[0];
        stringArray2[0] = "";
        this.setClassifier(Classifier.forName(string2, stringArray2));
        String string3 = Utils.getOption('S', stringArray);
        if (string3.length() != 0) {
            Double d = Double.valueOf(string3);
            this.setShrinkage(d);
        }
        if ((string3 = Utils.getOption('M', stringArray)).length() != 0) {
            this.setMaxModels(Integer.parseInt(string3));
        }
        Utils.checkForRemainingOptions(stringArray);
    }

    public String[] getOptions() {
        String[] stringArray = new String[7];
        int n = 0;
        if (this.getDebug()) {
            stringArray[n++] = "-D";
        }
        stringArray[n++] = "-W";
        stringArray[n++] = "" + this.getClassifierSpec();
        stringArray[n++] = "-S";
        stringArray[n++] = "" + this.getShrinkage();
        stringArray[n++] = "-M";
        stringArray[n++] = "" + this.getMaxModels();
        while (n < stringArray.length) {
            stringArray[n++] = "";
        }
        return stringArray;
    }

    public String debugTipText() {
        return "Turn on debugging output";
    }

    public void setDebug(boolean bl) {
        this.m_debug = bl;
    }

    public boolean getDebug() {
        return this.m_debug;
    }

    public String classifierTipText() {
        return "Classifier to use";
    }

    public void setClassifier(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    protected String getClassifierSpec() {
        Classifier classifier = this.getClassifier();
        if (classifier instanceof OptionHandler) {
            return classifier.getClass().getName() + " " + Utils.joinOptions(classifier.getOptions());
        }
        return classifier.getClass().getName();
    }

    public String maxModelsTipText() {
        return "Max models to generate. <= 0 indicates no maximum, ie. continue until error reduction threshold is reached.";
    }

    public void setMaxModels(int n) {
        this.m_maxModels = n;
    }

    public int getMaxModels() {
        return this.m_maxModels;
    }

    public String shrinkageTipText() {
        return "Shrinkage rate. Smaller values help prevent overfitting and have a smoothing effect (but increase learning time). Default = 1.0, ie. no shrinkage.";
    }

    public void setShrinkage(double d) {
        this.m_shrinkage = d;
    }

    public double getShrinkage() {
        return this.m_shrinkage;
    }

    public void buildClassifier(Instances instances) throws Exception {
        int n;
        this.m_additiveModels = new FastVector();
        if (this.m_Classifier == null) {
            throw new Exception("No base classifiers have been set!");
        }
        if (instances.classAttribute().isNominal()) {
            throw new UnsupportedClassTypeException("Class must be numeric!");
        }
        Instances instances2 = new Instances(instances);
        instances2.deleteWithMissingClass();
        this.m_classIndex = instances2.classIndex();
        double d = 0.0;
        double d2 = 0.0;
        ZeroR zeroR = new ZeroR();
        zeroR.buildClassifier(instances2);
        this.m_additiveModels.addElement(zeroR);
        instances2 = this.residualReplace(instances2, zeroR, false);
        for (n = 0; n < instances2.numInstances(); ++n) {
            d += instances2.instance(n).weight() * instances2.instance(n).classValue() * instances2.instance(n).classValue();
        }
        if (this.m_debug) {
            System.err.println("Sum of squared residuals (predicting the mean) : " + d);
        }
        n = 0;
        do {
            d2 = d;
            Classifier classifier = Classifier.makeCopies(this.m_Classifier, 1)[0];
            classifier.buildClassifier(instances2);
            this.m_additiveModels.addElement(classifier);
            instances2 = this.residualReplace(instances2, classifier, true);
            d = 0.0;
            for (int i = 0; i < instances2.numInstances(); ++i) {
                d += instances2.instance(i).weight() * instances2.instance(i).classValue() * instances2.instance(i).classValue();
            }
            if (!this.m_debug) continue;
            System.err.println("Sum of squared residuals : " + d);
        } while (d2 - d > Utils.SMALL && (this.m_maxModels <= 0 || ++n < this.m_maxModels));
    }

    public double classifyInstance(Instance instance) throws Exception {
        double d = 0.0;
        for (int i = 0; i < this.m_additiveModels.size(); ++i) {
            Classifier classifier = (Classifier)this.m_additiveModels.elementAt(i);
            double d2 = classifier.classifyInstance(instance);
            if (i > 0) {
                d2 *= this.getShrinkage();
            }
            d += d2;
        }
        return d;
    }

    private Instances residualReplace(Instances instances, Classifier classifier, boolean bl) throws Exception {
        Instances instances2 = new Instances(instances);
        for (int i = 0; i < instances2.numInstances(); ++i) {
            double d = classifier.classifyInstance(instances2.instance(i));
            if (bl) {
                d *= this.getShrinkage();
            }
            double d2 = instances2.instance(i).classValue() - d;
            instances2.instance(i).setClassValue(d2);
        }
        return instances2;
    }

    public Enumeration enumerateMeasures() {
        Vector<String> vector = new Vector<String>(1);
        vector.addElement("measureNumIterations");
        return vector.elements();
    }

    public double getMeasure(String string) {
        if (string.compareToIgnoreCase("measureNumIterations") == 0) {
            return this.measureNumIterations();
        }
        throw new IllegalArgumentException(string + " not supported (AdditiveRegression)");
    }

    public double measureNumIterations() {
        return this.m_additiveModels.size();
    }

    public String toString() {
        StringBuffer stringBuffer = new StringBuffer();
        if (this.m_additiveModels.size() == 0) {
            return "Classifier hasn't been built yet!";
        }
        stringBuffer.append("Additive Regression\n\n");
        stringBuffer.append("Base classifier " + this.getClassifier().getClass().getName() + "\n\n");
        stringBuffer.append("" + this.m_additiveModels.size() + " models generated.\n");
        for (int i = 0; i < this.m_additiveModels.size(); ++i) {
            stringBuffer.append("\nModel number " + i + "\n\n" + this.m_additiveModels.elementAt(i) + "\n");
        }
        return stringBuffer.toString();
    }

    public static void main(String[] stringArray) {
        try {
            System.out.println(Evaluation.evaluateModel(new AdditiveRegression(), stringArray));
        }
        catch (Exception exception) {
            System.err.println(exception.getMessage());
        }
    }
}

