package nets;

import java.util.Map;

/* loaded from: input_file:agentDemonstrator/nets/BackProp.class */
/**
 * Trains a {@link TwoLayerNet} with per-example (stochastic) error back-propagation.
 *
 * <p>Training and test sets are maps whose keys are input vectors ({@code double[]}) and whose
 * values are the corresponding target output vectors ({@code double[]}).
 *
 * <p>Two integer switches recur throughout the API:
 * <ul>
 *   <li>{@code verbosity}: any positive value prints the net before a training run and makes
 *       {@link #test} print per-example details. Beyond that, exactly one extra print point is
 *       enabled: 1 prints the net after the whole run, 2 after every epoch, 3 after every
 *       example.</li>
 *   <li>{@code testMode}: 1 runs {@link #test} on the test set at the end of a run, 2 after every
 *       epoch, 3 after every example. Any other value disables testing.</li>
 * </ul>
 */
public class BackProp {
    /** The network being trained; every training method returns this same instance. */
    private final TwoLayerNet net;
    /** Incoming weights per hidden unit: one per raw input, plus weight 0, which is fed a constant 1.0. */
    private final int numOfInputs;
    /** Hidden units including index 0, whose own incoming weights are never trained here. */
    private final int numOfHiddenUnits;
    /** Number of output units. */
    private final int numOfOutputs;

    /**
     * Creates a trainer around a freshly constructed network.
     *
     * @param inputs number of raw network inputs
     * @param hiddenUnits number of trainable hidden units (an extra unit is reserved at index 0,
     *     presumably a bias unit; confirm in {@code TwoLayerNet})
     * @param outputs number of output units
     */
    public BackProp(int inputs, int hiddenUnits, int outputs) {
        this.net = new TwoLayerNet(inputs, hiddenUnits, outputs);
        this.numOfInputs = inputs + 1;
        this.numOfHiddenUnits = hiddenUnits + 1;
        this.numOfOutputs = outputs;
    }

    /**
     * Trains for a fixed number of epochs.
     *
     * @param epochs number of passes over the training set
     * @param learningRate step size applied to every weight update
     * @param trainingSet input vector to target vector
     * @param verbosity print level (see class documentation)
     * @param testSet input vector to target vector, used for {@link #test}
     * @param testMode when to evaluate on {@code testSet} (see class documentation)
     * @return the trained network
     */
    public TwoLayerNet backProp(int epochs, double learningRate, Map<?, ?> trainingSet,
            int verbosity, Map<?, ?> testSet, int testMode) {
        if (verbosity > 0) {
            System.out.println(net);
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            backPropOneEpoch(learningRate, trainingSet, verbosity, testSet, testMode);
        }
        if (verbosity == 1) {
            System.out.println(net);
        }
        if (testMode == 1) {
            test(testSet, verbosity);
        }
        return net;
    }

    /**
     * Trains until the total absolute error on the training set is at most {@code maxError}.
     *
     * <p>At least one epoch always runs. There is no epoch cap: if the network never reaches
     * {@code maxError}, this method does not return.
     *
     * @param maxError convergence threshold on {@link #test}'s total error over {@code trainingSet}
     * @param learningRate step size applied to every weight update
     * @param trainingSet input vector to target vector
     * @param verbosity print level (see class documentation)
     * @param testSet input vector to target vector, used for {@link #test}
     * @param testMode when to evaluate on {@code testSet} (see class documentation)
     * @return the trained network
     */
    public TwoLayerNet backProp(double maxError, double learningRate, Map<?, ?> trainingSet,
            int verbosity, Map<?, ?> testSet, int testMode) {
        if (verbosity > 0) {
            System.out.println(net);
        }
        int epochs = 0;
        do {
            backPropOneEpoch(learningRate, trainingSet, verbosity, testSet, testMode);
            epochs++;
        } while (test(trainingSet, 0) > maxError); // verbosity 0: convergence check is silent
        if (verbosity == 1) {
            System.out.println("Training converged after " + epochs + " epochs.\n");
            System.out.println(net);
        }
        if (testMode == 1) {
            test(testSet, verbosity);
        }
        return net;
    }

    /**
     * Runs one back-propagation step for every example in {@code trainingSet}, in the map's
     * iteration order.
     *
     * @param learningRate step size applied to every weight update
     * @param trainingSet input vector to target vector
     * @param verbosity print level (see class documentation)
     * @param testSet input vector to target vector, used for {@link #test}
     * @param testMode when to evaluate on {@code testSet} (see class documentation)
     * @return the trained network
     */
    public TwoLayerNet backPropOneEpoch(double learningRate, Map<?, ?> trainingSet, int verbosity,
            Map<?, ?> testSet, int testMode) {
        for (Map.Entry<?, ?> example : trainingSet.entrySet()) {
            backPropOneExample(learningRate, (double[]) example.getKey(),
                    (double[]) example.getValue(), verbosity, testSet, testMode);
        }
        if (verbosity == 2) {
            System.out.println(net);
        }
        if (testMode == 2) {
            test(testSet, verbosity);
        }
        return net;
    }

    /**
     * Performs one forward pass on {@code input} and one weight update for both layers.
     *
     * @param learningRate step size applied to every weight update
     * @param input raw input vector (must have at least {@code inputs} elements)
     * @param target desired output vector (must have at least {@code outputs} elements)
     * @param verbosity print level (see class documentation)
     * @param testSet input vector to target vector, used for {@link #test}
     * @param testMode when to evaluate on {@code testSet} (see class documentation)
     * @return the trained network
     */
    public TwoLayerNet backPropOneExample(double learningRate, double[] input, double[] target,
            int verbosity, Map<?, ?> testSet, int testMode) {
        // The forward pass must come first: the units' outputs and gradients read below come from it.
        double[] output = net.activate(input);
        double[] hiddenError = updateOutputLayer(learningRate, target, output);
        updateHiddenLayer(learningRate, input, hiddenError);
        if (verbosity == 3) {
            System.out.println(net);
        }
        if (testMode == 3) {
            test(testSet, verbosity);
        }
        return net;
    }

    /**
     * Updates every hidden-to-output weight with the delta rule.
     *
     * @return for each hidden unit j, the sum over outputs k of w(j,k) * delta(k), using the
     *     pre-update weights
     */
    private double[] updateOutputLayer(double learningRate, double[] target, double[] output) {
        double[] backpropagated = new double[numOfHiddenUnits];
        for (int k = 0; k < numOfOutputs; k++) {
            double delta = ((OutputLayerTLU) net.getOutputUnit(k)).getGradient()
                    * (target[k] - output[k]);
            for (int j = 0; j < numOfHiddenUnits; j++) {
                double weight = net.getOutputLayerWeight(j, k);
                // Back-propagate through the weight used in the forward pass, not the weight
                // after this step's update; otherwise the hidden deltas are not the error gradient.
                backpropagated[j] += weight * delta;
                net.setOutputLayerWeight(j, k,
                        weight + (learningRate * net.getHiddenUnit(j).getOutput() * delta));
            }
        }
        return backpropagated;
    }

    /** Updates every input-to-hidden weight from the back-propagated hidden errors. */
    private void updateHiddenLayer(double learningRate, double[] input, double[] hiddenError) {
        // Hidden unit 0 is skipped: its incoming weights are not trained.
        for (int j = 1; j < numOfHiddenUnits; j++) {
            double delta = hiddenError[j] * ((HiddenLayerTLU) net.getHiddenUnit(j)).getGradient();
            // Weight 0 sees a constant input of 1.0; weights 1..n see input[0..n-1].
            net.setHiddenLayerWeight(0, j, net.getHiddenLayerWeight(0, j) + (learningRate * 1.0d * delta));
            for (int i = 1; i < numOfInputs; i++) {
                net.setHiddenLayerWeight(i, j,
                        net.getHiddenLayerWeight(i, j) + (learningRate * input[i - 1] * delta));
            }
        }
    }

    /**
     * Evaluates the network on every example in {@code testSet}.
     *
     * <p>The per-example error is the sum of absolute differences between the target and the
     * actual output. This method runs forward passes, so it overwrites the units' activation
     * state.
     *
     * @param testSet input vector to target vector
     * @param verbosity if positive, prints each example and the total
     * @return the summed error over all examples
     */
    public double test(Map<?, ?> testSet, int verbosity) {
        double totalError = 0.0d;
        for (Map.Entry<?, ?> example : testSet.entrySet()) {
            double[] input = (double[]) example.getKey();
            double[] target = (double[]) example.getValue();
            double[] actual = net.activate(input);
            double error = 0.0d;
            for (int k = 0; k < actual.length; k++) {
                error += Math.abs(target[k] - actual[k]);
            }
            totalError += error;
            if (verbosity > 0) {
                System.out.println("Input " + arrayToString(input) + "  Target " + arrayToString(target)
                        + "  Actual " + arrayToString(actual) + "  Error " + error);
            }
        }
        if (verbosity > 0) {
            System.out.println("Total error " + totalError);
        }
        return totalError;
    }

    /** Formats {@code values} as the elements, each followed by a single space. */
    private static String arrayToString(double[] values) {
        StringBuilder sb = new StringBuilder();
        for (double v : values) {
            sb.append(v).append(' ');
        }
        return sb.toString();
    }
}
