package nets;
import java.util.*;

/**
 * A class that implements the back-propagation learning algorithm.
 * @author Derek Bridge
 */
public class BackProp
{
/* =======================================================================
       CONSTRUCTORS
   =======================================================================
*/

   /**
    * Allocates a new learner for a net having the given dimensions.
    * The learner creates its own {@link TwoLayerNet} with these dimensions.
    * Internally the input and hidden layer sizes are stored with one
    * added, to account for the `extra' unit in each of those layers.
    * @param theNumOfInputs the number of inputs.
    * @param theNumOfHiddenUnits the number of hidden units.
    * @param theNumOfOutputs the number of output units.
    * (These figures exclude `extra' units used to replace thresholds.)
    */
   public BackProp(int theNumOfInputs, int theNumOfHiddenUnits,
      int theNumOfOutputs)
   {  net = new TwoLayerNet(theNumOfInputs, theNumOfHiddenUnits, 
         theNumOfOutputs);
      numOfInputs = theNumOfInputs + 1;
      numOfHiddenUnits = theNumOfHiddenUnits + 1;
      numOfOutputs = theNumOfOutputs;
   }

/* =======================================================================
       PUBLIC INTERFACE
   =======================================================================
*/

/* --Setters----------------------------------------------------------- */

   /**
    * Invokes the back-propagation learning algorithm for the given
    * number of epochs, learning rate and examples.
    * @param theNumOfEpochs the number of epochs.
    * @param theLearningRate the learning rate.
    * @param theTrainingSet a `map' (e.g. hashtable) in which inputs
    * are paired with their corresponding target outputs. (The inputs
    * are arrays of sensor values, one per input unit (sensor); 
    * the target outputs are arrays of values, one per output unit).
    * @param theDebugLevel an integer used to control how much info to
    * output. (If == 0, then no output; if == 1, then output start and
    * end nets; if == 2, then output start net and net at end of each
    * epoch; if == 3, then output start net and net at end of each example.)
    * @param theTestSet a `map' (e.g. hashtable) in which inputs are
    * paired with their corresponding target outputs.
    * @param theTestLevel an integer used to control how often the net
    * should be tested. (If == 0, it won't be tested at all and theTestSet
    * can be null; if == 1, it will be tested only after all training (all
    * epochs); if == 2, it will be tested at the end of each epoch; if == 3,
    * it will be tested after each example.)
    * @return the learned net.
    */
   public TwoLayerNet backProp(int theNumOfEpochs, 
      double theLearningRate, Map theTrainingSet, int theDebugLevel,
      Map theTestSet, int theTestLevel)
   {  /* Every non-zero debug level shows the net before training.
       */
      if (theDebugLevel > 0)
      {  System.out.println(net);
      }
      /* Train the net for the given number of epochs.
       */
      for (int e = 0; e < theNumOfEpochs; e++)
      {  backPropOneEpoch(theLearningRate, theTrainingSet,
            theDebugLevel, theTestSet, theTestLevel);
      } 
      /* Levels 2 and 3 have already printed the net during training,
         so only level 1 needs the final net printed here.
       */
      if (theDebugLevel == 1)
      {  System.out.println(net);
      }
      if (theTestLevel == 1)
      {  test(theTestSet, theDebugLevel);
      }
      return net;
   }

   /**
    * Invokes the back-propagation learning algorithm for the given
    * learning rate and examples until the net converges (total error
    * is less than or equal to the given allowable error).
    * At least one epoch is always run. There is no upper limit on the
    * number of epochs: if the total error on the training set never
    * reaches the allowable error, this method does not return.
    * @param theAllowableError the amount of error we will tolerate. The
    * net has converged when its error equals or falls below this.
    * @param theLearningRate the learning rate.
    * @param theTrainingSet a `map' (e.g. hashtable) in which inputs
    * are paired with their corresponding target outputs. (The inputs
    * are arrays of sensor values, one per input unit (sensor); 
    * the target outputs are arrays of values, one per output unit).
    * @param theDebugLevel an integer used to control how much info to
    * output. 
    * @param theTestSet a `map' (e.g. hashtable) in which inputs are
    * paired with their corresponding target outputs.
    * @param theTestLevel an integer used to control how often the net
    * should be tested.
    * @return the learned net.
    */
   public TwoLayerNet backProp(double theAllowableError,
      double theLearningRate, Map theTrainingSet, int theDebugLevel,
      Map theTestSet, int theTestLevel)
   {  if (theDebugLevel > 0)
      {  System.out.println(net);
      }
      /* Train the net until its error is tolerable.
       */
      int e = 0;
      double totalError = 0.0;
      do      
      {  backPropOneEpoch(theLearningRate, theTrainingSet,
            theDebugLevel, theTestSet, theTestLevel);
         /* After this epoch, how much error are we getting
            on the training set. (Doing this separately is
            an inefficiency in this program.)
          */
         e++;
         totalError = test(theTrainingSet, 0); // debug level of 0 = no output
      } while (totalError > theAllowableError);
      if (theDebugLevel == 1)
      {  System.out.println("Training converged after " + e + " epochs.\n");
         System.out.println(net);
      }
      if (theTestLevel == 1)
      {  test(theTestSet, theDebugLevel);
      }
      return net;
   }

   /**
    * Invokes the back-propagation learning algorithm for one epoch,
    * i.e. one pass over every example in the training set, in the
    * iteration order of the map's entry set.
    * @param theLearningRate the learning rate.
    * @param theTrainingSet a `map' (e.g. hashtable) in which inputs
    * are paired with their corresponding target outputs. (The inputs
    * are arrays of sensor values, one per input unit (sensor); 
    * the target outputs are arrays of values, one per output unit).
    * @param theDebugLevel an integer used to control how much info to
    * output.
    * @param theTestSet a `map' (e.g. hashtable) in which inputs are
    * paired with their corresponding target outputs.
    * @param theTestLevel an integer used to control how often the net
    * should be tested.
    * @return the learned net.
    */
   public TwoLayerNet backPropOneEpoch(double theLearningRate, 
      Map theTrainingSet, int theDebugLevel, Map theTestSet,
      int theTestLevel)
   {  Iterator iter = theTrainingSet.entrySet().iterator();
      /* Train the net on each example.
       */
      while (iter.hasNext())
      {  Map.Entry example = (Map.Entry) iter.next();
         double[] exampleInput = (double[]) example.getKey();
         double[] targetOutput = (double[]) example.getValue();
         backPropOneExample(theLearningRate, exampleInput, targetOutput, 
            theDebugLevel, theTestSet, theTestLevel);
      }
      if (theDebugLevel == 2)
      {  System.out.println(net);
      }
      if (theTestLevel == 2)
      {  test(theTestSet, theDebugLevel);
      }
      return net;
   }

   /**
    * Invokes the back-propagation learning algorithm for one example.
    * The net is activated on the example, the output-layer weights are
    * adjusted, and then the hidden-layer weights are adjusted. The
    * error terms for the hidden layer are computed from the output-layer
    * weights as they were when the net was activated, not from the
    * freshly adjusted ones.
    * @param theLearningRate the learning rate.
    * @param exampleInput the sensor values.
    * @param targetOutput the corresponding desired output values.
    * @param theDebugLevel an integer used to control how much info to
    * output.
    * @param theTestSet a `map' (e.g. hashtable) in which inputs are
    * paired with their corresponding target outputs.
    * @param theTestLevel an integer used to control how often the net
    * should be tested.
    * @return the learned net.
    */
   public TwoLayerNet backPropOneExample(double theLearningRate, 
      double[] exampleInput, double[] targetOutput, 
      int theDebugLevel, Map theTestSet, int theTestLevel)
   {  double[] output;
      OutputLayerTLU outputUnit;
      double[] deltaO = new double[numOfOutputs];
      HiddenLayerTLU hiddenUnit;
      /* Starts as all zeros; accumulates, per hidden unit, the sum of
         (outgoing weight * output-layer delta).
       */
      double[] deltaH = new double[numOfHiddenUnits];
      /* Compute the actual output we get for this example input.
       */
      output = net.activate(exampleInput);
      /* Compute error at output layer and update weights
         between hidden layer and output layer.
       */
      for (int k = 0; k < numOfOutputs; k++)
      {  double err = targetOutput[k] - output[k];
         outputUnit = (OutputLayerTLU) net.getOutputUnit(k);
         deltaO[k] = outputUnit.getGradient() * err;
         /* Adjust each incoming weight.
          */
         for (int j = 0; j < numOfHiddenUnits; j++)
         {  double oldWeight = net.getOutputLayerWeight(j, k);
            /* The weight times deltaO can be added into deltaH for
               use below. This must use the weight as it was during
               activation, so it is taken before the weight is changed.
             */
            deltaH[j] += oldWeight * deltaO[k];
            net.setOutputLayerWeight(j, k,
               oldWeight +
               theLearningRate * net.getHiddenUnit(j).getOutput() * 
               deltaO[k]);
         }
      }
      /* Compute error at hidden layer and update weights
         between input layer and hidden layer. Hidden unit 0 is
         skipped: no incoming weights are adjusted for it (presumably
         because it is the `extra' unit that replaces the thresholds).
       */
      for (int j = 1; j < numOfHiddenUnits; j++)
      {  /* The sum of the weights and the output layer delta terms
            are already in deltaH. So just multiply by the gradient.
          */
         hiddenUnit = (HiddenLayerTLU) net.getHiddenUnit(j);
         deltaH[j] *= hiddenUnit.getGradient();
         /* Adjust each incoming weight. Input 0 is the `extra' input,
            whose value is taken to be the constant 1.0, so its weight
            changes by just the learning rate times the delta.
          */
         net.setHiddenLayerWeight(0, j,
               net.getHiddenLayerWeight(0, j) +
               theLearningRate * deltaH[j]);
         /* The remaining inputs are the sensor values; input unit i
            corresponds to exampleInput[i - 1].
          */
         for (int i = 1; i < numOfInputs; i++)
         {  net.setHiddenLayerWeight(i, j,
               net.getHiddenLayerWeight(i, j) +
               theLearningRate * exampleInput[i - 1] * deltaH[j]);
         }
      } 
      if (theDebugLevel == 3)
      {  System.out.println(net);
      }
      if (theTestLevel == 3)
      {  test(theTestSet, theDebugLevel);
      }
      return net;
   }

   /**
    * Tests a net on a set of examples. The error for one example is the
    * sum, over the output units, of the absolute difference between the
    * target and the actual output; the total error is the sum of these
    * over all the examples. The net's weights are not changed here.
    * @param theTestSet the examples to use for testing. Must not be null.
    * @param theDebugLevel an integer used to control how much info to
    * output. (If greater than 0, each example and the total error are
    * printed.)
    * @return the total error this net produces for these examples.
    * @throws NullPointerException if theTestSet is null.
    */
   public double test(Map theTestSet, int theDebugLevel)
   {  if (theTestSet == null)
      {  throw new NullPointerException(
            "theTestSet must not be null when testing is requested");
      }
      double totalError = 0.0;
      Iterator iter = theTestSet.entrySet().iterator();
      while (iter.hasNext())
      {  Map.Entry example = (Map.Entry) iter.next();
         double[] exampleInput = (double[]) example.getKey();
         double[] targetOutput = (double[]) example.getValue();
         double[] output = net.activate(exampleInput);
         double exampleError = 0.0;
         for (int i = 0; i < output.length; i++)
         {  exampleError += Math.abs(targetOutput[i] - output[i]);
         }
         totalError += exampleError;
         if (theDebugLevel > 0)
         {  System.out.println("Input " + arrayToString(exampleInput) +
              "  Target " + arrayToString(targetOutput) +
              "  Actual " + arrayToString(output) +
              "  Error " + exampleError);
         }
      }
      if (theDebugLevel > 0)
      {  System.out.println("Total error " + totalError);
      }
      return totalError;
   }

/* =======================================================================
       HELPER METHODS
   =======================================================================
*/

   /**
    * Turns an array of doubles into a String. Each element is followed
    * by a single space (including the last one).
    * @param a the array.
    * @return the elements of the array, each followed by a space.
    */
   private static String arrayToString(double[] a)
   {  StringBuffer sb = new StringBuffer();
      for (int i = 0; i < a.length; i++)
      {  sb.append(a[i]).append(' ');
      }
      return sb.toString();
   }
  
/* =======================================================================
       INSTANCE VARIABLES & CLASS VARIABLES
   =======================================================================
*/

   /** The net being trained; created in the constructor. */
   private final TwoLayerNet net;
   /** Number of input units, including the `extra' unit. */
   private final int numOfInputs;
   /** Number of hidden units, including the `extra' unit. */
   private final int numOfHiddenUnits;
   /** Number of output units. */
   private final int numOfOutputs;
}
