Move Prediction Using Hidden Markov Model

Description: This deliverable uses a Hidden Markov Model (HMM) to predict the next move a player in Alpha Fighter might make. Following an idea from Statistical Language Learning by Eugene Charniak (Chapter 3), one can extend the n-gram model with an HMM. This overcomes a major problem with n-grams: at higher values of n, the model can become too specific. If the model comes across a sequence not seen in the training corpus, it assigns it a probability of 0. Low values of n, however, can make the model too general and give up the accuracy of higher n values. The HMM provides a mechanism to fall back on more general n-grams when the higher-order ones suffer from this overfitting problem. This keeps the accuracy of a high n while allowing fallback to lower n values if necessary. This particular program keeps track of a unigram, a bigram, and a trigram. Each of these n-grams has an associated weight, with higher n values getting higher weights (0.1, 0.3, and 0.6 by default). Therefore, the highest-order n-gram (the trigram) will "win" if it is useful, but the program falls back on the lower-order n-grams if the trigram is suffering from overfitting.
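
The weighting scheme described above can be illustrated with a minimal sketch. It is not the applet's code: it swaps the tree for a map keyed by context string, and the names BackoffSketch, observe, train, and predict are hypothetical. The weights are assumed to match the defaults in the Hmm listing. As in that listing's predict() method, the sketch picks the single move whose weighted probability is highest rather than summing the estimates.

```java
import java.util.HashMap;
import java.util.Map;

/** Sketch of weighted n-gram selection; all names here are illustrative. */
class BackoffSketch {
  // Assumed weights for unigram, bigram, trigram (they sum to 1).
  static final float[] WEIGHTS = {0.1f, 0.3f, 0.6f};

  // counts.get(context) maps each move to how often it followed that context.
  // The empty context "" holds the unigram counts.
  static final Map<String, Map<Character, Integer>> counts = new HashMap<>();

  /** Records one observation of a move following the given context. */
  static void observe(String context, char move) {
    counts.computeIfAbsent(context, k -> new HashMap<>()).merge(move, 1, Integer::sum);
  }

  /** Feeds a whole move sequence, updating unigram, bigram, and trigram counts. */
  static void train(String moves) {
    for (int i = 0; i < moves.length(); i++) {
      for (int n = 0; n < WEIGHTS.length && n <= i; n++) {
        observe(moves.substring(i - n, i), moves.charAt(i));
      }
    }
  }

  /** Returns the move with the highest weighted probability, or 0 if none. */
  static char predict(String history) {
    char best = 0;
    float bestScore = 0f;
    for (int n = 0; n < WEIGHTS.length && n <= history.length(); n++) {
      Map<Character, Integer> next = counts.get(history.substring(history.length() - n));
      if (next == null) continue; // unseen context: this n-gram contributes nothing
      int total = 0;
      for (int c : next.values()) total += c;
      for (Map.Entry<Character, Integer> e : next.entrySet()) {
        float score = WEIGHTS[n] * e.getValue() / total;
        if (score > bestScore) { bestScore = score; best = e.getKey(); }
      }
    }
    return best;
  }

  public static void main(String[] args) {
    train("ooplop");
    System.out.println(predict("op")); // context "op" was seen, so the trigram decides
    System.out.println(predict(";;")); // unseen context, so only the unigram applies
  }
}
```

With the history "op", the trigram has seen the context and its weight of 0.6 dominates. With the unseen history ";;", the bigram and trigram lookups find nothing, so the prediction comes from the unigram counts alone.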

Example: This is the move prediction applet.

[HMM applet]

The text area on the left displays the currently trained model as a tree. The number next to each move code is the count of that move at that level, from which its probability is computed. The first column holds the unigram statistics for each move, the second column the bigram statistics, and the third column the trigram statistics. Training takes place automatically as the user inputs moves.
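
The tree view described above can be sketched as follows. This is an illustration only, with hypothetical names (TreeViewSketch, add, render); it is not the code used by the applet. It prints each node in the moveCode( count ) form used by Node.toString() in the listing and, as an assumption for illustration, also prints the conditional probability derived from the counts.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Sketch of an indented count tree; class and method names are hypothetical. */
class TreeViewSketch {
  final Map<Character, TreeViewSketch> children = new LinkedHashMap<>();
  int count;          // times this node was reached from its parent
  int childrenCount;  // sum of the counts of all children

  /** Walks down from this node along the given moves, counting each step. */
  void add(String moves) {
    TreeViewSketch node = this;
    for (char m : moves.toCharArray()) {
      TreeViewSketch child = node.children.computeIfAbsent(m, k -> new TreeViewSketch());
      child.count++;
      node.childrenCount++;
      node = child;
    }
  }

  /** Renders the subtree, one node per line, indented two spaces per level. */
  String render(int level) {
    StringBuilder out = new StringBuilder();
    for (Map.Entry<Character, TreeViewSketch> e : children.entrySet()) {
      TreeViewSketch child = e.getValue();
      for (int i = 0; i < level * 2; i++) out.append(' ');
      // P(child | parent) = child count / sum of all sibling counts
      out.append(e.getKey()).append("( ").append(child.count).append(" ) p=")
         .append((float) child.count / childrenCount).append('\n');
      out.append(child.render(level + 1));
    }
    return out.toString();
  }

  public static void main(String[] args) {
    TreeViewSketch root = new TreeViewSketch();
    root.add("opl");
    root.add("opp");
    root.add("lo");
    System.out.print(root.render(0));
  }
}
```

Each level of indentation corresponds to one more move of context: the unindented lines are unigram entries, the next level bigram entries, and the level after that trigram entries.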

To input a move, type a valid move character in the "Input" text field. Valid move codes are the characters 'o', 'p', 'l', and ';' (quotes for clarity). Case does not matter. Their meanings are weak-high attack, strong-high attack, weak-low attack, and strong-low attack, respectively. Once the user inputs a move, it is incorporated into the trained model, and the prediction of the most likely next move is updated under the label "Prediction." The last two moves entered appear under "History," with the older one on top. At any time the user can clear the model of all training by clicking the "Clear Model" button.
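
The input rules above (four valid move codes, case-insensitive, and a history of the last two moves) can be sketched as follows. The class and method names (MoveInputSketch, describe, input) are hypothetical and chosen for illustration; the applet itself handles input through its key listener and the Hmm class.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** Sketch of input validation and the two-move history; names are hypothetical. */
class MoveInputSketch {
  static final int HISTORY_SIZE = 2;
  final Deque<Character> history = new ArrayDeque<>();

  /** Maps a key to its move name, ignoring case; returns null for invalid keys. */
  static String describe(char key) {
    switch (Character.toLowerCase(key)) {
      case 'o': return "weak-high attack";
      case 'p': return "strong-high attack";
      case 'l': return "weak-low attack";
      case ';': return "strong-low attack";
      default:  return null;
    }
  }

  /** Accepts a valid key and keeps only the most recent moves; returns false if ignored. */
  boolean input(char key) {
    if (describe(key) == null) return false;
    history.addLast(Character.toLowerCase(key));
    if (history.size() > HISTORY_SIZE) history.removeFirst();
    return true;
  }

  public static void main(String[] args) {
    MoveInputSketch moves = new MoveInputSketch();
    for (char key : "oXPl;".toCharArray()) {
      System.out.println(key + " -> " + (moves.input(key) ? describe(key) : "ignored"));
    }
    System.out.println("History (oldest first): " + moves.history);
  }
}
```

Normalizing the case before validating means that 'P' and 'p' are recorded as the same move, which matches the "case does not matter" rule stated above.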


------------------------------- HmmApplet.java -----------------------------
import java.applet.Applet;
import java.awt.Component;
import java.awt.Container;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Frame;
import java.awt.Insets;
import java.awt.LayoutManager;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.KeyEvent;
import java.awt.event.KeyListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;

import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JLabel;
import javax.swing.JScrollPane;
import javax.swing.JTextArea;
import javax.swing.JTextField;

/**
 * Copyright (c) 2004 Leo Lee. All Rights Reserved.
 * Most of the GUI code was generated using Java GUI Builder v1.3.
 *
 * Applet that uses an HMM, which consolidates unigram, bigram,
 * and trigram models, to predict the next move in a sequence of moves.
 * Conceptually, given the previous few moves, the HMM finds the probability
 * of the next move using various n-gram models.  It weighs the prediction of
 * each model by a factor, giving more weight to higher n values.  This allows
 * the system to fall back to a lower, but more general model if a more
 * specific model fails to find a match.
 *
 * Note: This is a modification of the HMM Word Predictor.  It was modified
 * to put the program in the context of the game to be developed, "Alpha Fighter."
 *
 * @author Leo Lee
 * @version 1.0 11/10/2004
 */
public class HmmApplet extends Applet implements ActionListener, KeyListener
{
  public static void main(String args[])
  {
    HmmApplet applet = new HmmApplet();
    Frame window = new Frame("Move Prediction with Hidden Markov Model");

    window.addWindowListener(new WindowAdapter()
    {
      public void windowClosing(WindowEvent e)
      {
        System.exit(0);
      }
    });

    applet.init();
    window.add("Center", applet);
    window.pack();
    window.setVisible(true);
  }

  /**
   * Initialize the applet by placing the various GUI components
   * and instantiating classes.
   */
  public void init()
  {
    // Initialize the HMM class
    m_hmm = new Hmm(this);

    // Set up the GUI
    move_guiLayout customLayout = new move_guiLayout();

    setFont(new Font("Helvetica", Font.PLAIN, 12));
    setLayout(customLayout);

    m_taModel = new JTextArea("Model will display here after input.");
    sp_m_taModel = new JScrollPane(m_taModel);
    add(sp_m_taModel);

    m_lbModel = new JLabel("Model");
    add(m_lbModel);

    m_lbHistory = new JLabel("History");
    add(m_lbHistory);

    m_btClear = new JButton("Clear Model");
    add(m_btClear);

    m_lbPic1 = new JLabel("Last Last Move");
    m_lbPic1.setIcon(new ImageIcon(getImageName(m_hmm.getLastMove(1))));
    add(m_lbPic1);

    m_lbPrediction = new JLabel("Prediction");
    add(m_lbPrediction);

    m_lbPic2 = new JLabel("Last Move");
    m_lbPic2.setIcon(new ImageIcon(getImageName(m_hmm.getLastMove(0))));
    add(m_lbPic2);

    m_lbPic3 = new JLabel("Predicted Move");
    m_lbPic3.setIcon(new ImageIcon(getImageName(m_hmm.getPrediction())));
    add(m_lbPic3);

    m_lbInput = new JLabel("Input:");
    add(m_lbInput);

    m_tfInput = new JTextField("");
    add(m_tfInput);

    setSize(getPreferredSize());

    // Add listeners
    m_btClear.addActionListener(this);
    m_tfInput.addKeyListener(this);
  }

  /**
   * Refreshes the display in the model output textarea.
   * This method is called by the Hmm class whenever it has changed.
   */
  public void updateView()
  {
    m_taModel.setText(m_hmm.toString());
    m_lbPic1.setIcon(new ImageIcon(getImageName(m_hmm.getLastMove(1))));
    m_lbPic2.setIcon(new ImageIcon(getImageName(m_hmm.getLastMove(0))));
    m_lbPic3.setIcon(new ImageIcon(getImageName(m_hmm.getPrediction())));
  }

  private String getImageName(char moveCode)
  {
    if ( moveCode == Hmm.STRONG_HIGH )
    {
      return STRONG_HIGH_IMAGENAME;
    } else if ( moveCode == Hmm.STRONG_LOW )
    {
      return STRONG_LOW_IMAGENAME;
    } else if ( moveCode == Hmm.WEAK_HIGH )
    {
      return WEAK_HIGH_IMAGENAME;
    } else if (moveCode == Hmm.WEAK_LOW )
    {
      return WEAK_LOW_IMAGENAME;
    } else
    {
      return UNKNOWN_IMAGENAME;
    }
  }

  //---------------------------------- Listeners --------------------------//

  /**
   * Processes button clicks.
   * Clicking "Clear Model" button clears any previous training of HMM.
   */
  public void actionPerformed(ActionEvent event)
  {
    Object source = event.getSource();
    if ( source.equals(m_btClear) )
    {
      m_hmm.clear();
    }
  }

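  /**
   * Trains the model with the pressed key, recomputes the prediction,
   * and refreshes the display.  Keys that are not valid move codes
   * are ignored by Hmm.train().
   */
  public void keyPressed(KeyEvent event)
  {
    char key = event.getKeyChar();
    m_hmm.train(key);
    m_hmm.predict();
    updateView();
  }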

  public void keyReleased(KeyEvent event)
  {
  }

  public void keyTyped(KeyEvent event)
  {
  }

  // Constants
  private static final String STRONG_HIGH_IMAGENAME = "strong_high.jpg";
  private static final String WEAK_HIGH_IMAGENAME = "weak_high.jpg";
  private static final String STRONG_LOW_IMAGENAME = "strong_low.jpg";
  private static final String WEAK_LOW_IMAGENAME = "weak_low.jpg";
  private static final String UNKNOWN_IMAGENAME = "unknown.jpg";

  // Fields
  private Hmm m_hmm;

  // GUI components
  private JTextArea m_taModel;
  private JScrollPane sp_m_taModel;
  private JLabel m_lbModel;
  private JLabel m_lbHistory;
  private JButton m_btClear;
  private JLabel m_lbPic1;
  private JLabel m_lbPrediction;
  private JLabel m_lbPic2;
  private JLabel m_lbPic3;
  private JLabel m_lbInput;
  private JTextField m_tfInput;
}

//
// Layout manager that positions the applet's GUI components at fixed
// coordinates.  Generated with "Java GUI Builder v1.3"
//
class move_guiLayout implements LayoutManager {

   public move_guiLayout() {
   }

   public void addLayoutComponent(String name, Component comp) {
   }

   public void removeLayoutComponent(Component comp) {
   }

   public Dimension preferredLayoutSize(Container parent) {
      Dimension dim = new Dimension(0, 0);

      Insets insets = parent.getInsets();
      dim.width = 659 + insets.left + insets.right;
      dim.height = 534 + insets.top + insets.bottom;

      return dim;
   }

   public Dimension minimumLayoutSize(Container parent) {
      Dimension dim = new Dimension(0, 0);
      return dim;
   }

  public void layoutContainer(Container parent) {
     Insets insets = parent.getInsets();

     Component c;
     c = parent.getComponent(0);
     if (c.isVisible()) {c.setBounds(insets.left+8,insets.top+32,312,456);}
     c = parent.getComponent(1);
     if (c.isVisible()) {c.setBounds(insets.left+80,insets.top+8,176,16);}
     c = parent.getComponent(2);
     if (c.isVisible()) {c.setBounds(insets.left+376,insets.top+8,128,16);}
     c = parent.getComponent(3);
     if (c.isVisible()) {c.setBounds(insets.left+120,insets.top+496,104,24);}
     c = parent.getComponent(4);
     if (c.isVisible()) {c.setBounds(insets.left+336,insets.top+32,208,136);}
     c = parent.getComponent(5);
     if (c.isVisible()) {c.setBounds(insets.left+360,insets.top+328,160,16);}
     c = parent.getComponent(6);
     if (c.isVisible()) {c.setBounds(insets.left+336,insets.top+176,208,136);}
     c = parent.getComponent(7);
     if (c.isVisible()) {c.setBounds(insets.left+336,insets.top+352,208,136);}
     c = parent.getComponent(8);
     if (c.isVisible()) {c.setBounds(insets.left+336,insets.top+496,72,24);}
     c = parent.getComponent(9);
     if (c.isVisible()) {c.setBounds(insets.left+424,insets.top+496,72,24);}
  }
}

------------------------------- Hmm.java -----------------------------

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.ListIterator;

/**
 * Copyright (c) 2004 Leo Lee. All Rights Reserved.
 * Hidden Markov Model represented as a tree.  The HMM incorporates
 * a unigram, bigram and trigram, weighing the contributions of each.
 * Not counting the root node, the first level gives the unigram probabilities,
 * second level the bigram, and third level the trigram.
 * @author Leo Lee
 * @version 1.0 11/10/2004
 */
public class Hmm
{

  /**
   * Constructs an empty Hidden Markov Model tree structure.
   */
  public Hmm(HmmApplet applet)
  {
    m_applet = applet;
    m_root = new Node(' ', null);
    m_weights = new float[MAX_N];
    m_weights[0] = DEFAULT_UNIWEIGHT;
    m_weights[1] = DEFAULT_BIWEIGHT;
    m_weights[2] = DEFAULT_TRIWEIGHT;
    m_lastMoves = new LinkedList();
  }

  /**
   * Trains the HMM with the given move. The previous two moves (if any) are
   * kept as a side effect of this function.  The given move adds to the
   * unigram, bigram and trigram counts, in the context of the
   * previous moves for the latter two n-grams.
   * Note that this adds to the existing model.  The clear() method
   * should be called if a fresh, untrained model is desired.
   * @param move The new move observed (case does not matter).
   */
  public void train(char move)
  {
    // Normalize case so that 'O' and 'o' are treated as the same move
    move = Character.toLowerCase(move);

    // Make sure it's a valid move
    if ( !isMove(move) ) return;

    // Make the 3-move tuple and update the model with it
    m_lastMoves.addLast(Character.valueOf(move));
    updateModel(m_lastMoves);

    // Only keep the last 2 moves
    if ( m_lastMoves.size() > MAX_N - 1 )
    {
      m_lastMoves.removeFirst();
    }
  }

  /**
   * Predicts the most probable next move based on the current model
   * and the most recent moves.  The result is available through
   * getPrediction().
   */
  public void predict()
  {
    // Do the prediction for unigram, bigram, and trigram
    // models weighed by their respective weights
    Node curParent = null;
    // highest probability for respective n,
    // for the corresponding move in preds
    float probs[] = new float[MAX_N];
    // most probable move for respective n, 0 if non-existent
    char preds[] = new char[MAX_N];
    int n = 1; // 1 = unigram, 2 = bigram, etc.
    // Do the next n-gram as long as n <= our desired max and
    // we have at least n - 1 moves in the history list
    while ( n <= MAX_N && m_lastMoves.size() >= n - 1 )
    {
      // Find the appropriate node in the model
      ListIterator iterator = m_lastMoves.listIterator(m_lastMoves.size() - n + 1);
      curParent = traverse(iterator);

      // Compute the most probable child if possible
      if ( curParent != null )
      {
        ArrayList children = curParent.getChildren();

        int maxCount = 0;
        for ( int childIndex = 0; childIndex < children.size(); childIndex++ )
        {
          Node child = (Node)children.get(childIndex);
          if ( child.getCount() > maxCount )
          {
            maxCount = child.getCount();
            preds[n - 1] = child.getMoveCode();
          }
        }

        // Compute probability of highest count child, prob = 0.0 if no children
        if ( curParent.getChildrenCount() != 0 )
        {
          probs[n - 1] = (float)maxCount / curParent.getChildrenCount();
        }
      }

      //  Go to next n-gram
      n++;
    }

    // Find the n-gram with the most probable move weighed by the n-gram's weight
    int maxIndex = 0;
    for ( int i = 1; i < MAX_N; i++ )
    {
      if ( probs[i] * m_weights[i]
        > probs[maxIndex] * m_weights[maxIndex] ) maxIndex = i;
    }
    // preds[maxIndex] is 0 when no prediction could be made
    m_prediction = preds[maxIndex];
  }

  /**
   * Starting from the root of the tree and head of the given iterator,
   * use each element in iterator to traverse one level down the tree.
   * If for any reason cannot traverse all the way through to the end of the
   * iterator, null is returned.
   * @param iterator Elements used to traverse down the tree.
   * @return Node of the tree reached at the
   * end of the traversal. Null if traversal fails.
   */
  public Node traverse(ListIterator iterator)
  {
    Node curNode = m_root;
    while ( iterator.hasNext() )
    {
      Node tmpNode = new Node(((Character)iterator.next()).charValue(), curNode);
      int nextNodeIndex = curNode.getChildren().indexOf(tmpNode);
      if ( nextNodeIndex == -1 ) return null; // dead end in traversal
      curNode = (Node)curNode.getChildren().get(nextNodeIndex);
    }
    return curNode; // success in traversal
  }

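  /**
   * Clears the HMM so any previous training is wiped out.  The move
   * history and the current prediction are reset as well, so nothing
   * learned before the clear carries over.
   */
  public void clear()
  {
    m_root = new Node(' ', null);
    m_lastMoves.clear();
    m_prediction = 0;
    m_applet.updateView();
  }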

  /**
   * String representation of the model.
   * @return The string representation.
   */
  public String toString()
  {
    StringBuffer result = new StringBuffer();
    ArrayList rootChildren = m_root.getChildren();
    for ( int i = 0; i < rootChildren.size(); i++)
    {
      result.append(toStringHelp((Node)rootChildren.get(i), 0));
    }
    return result.toString();
  }

  public String toStringHelp(Node curNode, int level)
  {
    StringBuffer output = new StringBuffer();

    // First put myself in the output
    for( int i = 0; i < level * INDENT; i++ )
    {
      output.append(" ");
    }
    output.append(curNode.toString() + "\n");

    // Let my children go in output recursively
    ArrayList children = curNode.getChildren();
    for ( int i = 0; i < children.size(); i++ )
    {
      output.append(toStringHelp((Node)children.get(i), level + 1));
    }
    return output.toString();
  }

  /**
   * Iterates through the given moves from head to tail, starting from root of tree.
   * For each move:
   * If it is not in this level of the tree, add it as a node.
   * In either case, update the count of this node.
   * Move down a level, go to the next move and repeat.
   * @param words List of moves to add/update.
   */
  private void updateModel(LinkedList words)
  {
    Node curParent = m_root;
    ListIterator it = words.listIterator();
    while ( it.hasNext() )
    {
      Node curChild = new Node(((Character)it.next()).charValue(), curParent);
      int existingAt = curParent.getChildren().indexOf(curChild);
      if ( existingAt == -1 )
      {  // Add new node if not existing yet
        curParent.addChild(curChild);
      } else
      {  // Set curChild to existing child node
        curChild = (Node)curParent.getChildren().get(existingAt);
      }

      // Update the count for this node
      curChild.incCount();

      // Traverse down the tree
      curParent = curChild;
    }
  }

  /**
   * Gets the previous move index number back.
   * For example, 0 = last move, 1 = move before that, etc.
   * @param index Number to go back.
   * @return Previous move based on index as described.
   */
  public char getLastMove(int index)
  {
    if ( index < m_lastMoves.size() )
    {
      return ((Character)m_lastMoves.get(m_lastMoves.size() - index - 1)).charValue();
    }
    return 0; // index not in range
  }

  /**
   * Returns the current prediction of next move.
   * @return Current next move prediction.
   */
  public char getPrediction()
  {
    return m_prediction;
  }

  public boolean isMove(char move)
  {
    return move == WEAK_HIGH || move == STRONG_HIGH
      || move == WEAK_LOW || move == STRONG_LOW;
  }

  // Constants
  // Allowed moves
  public static final char WEAK_HIGH = 'o';
  public static final char STRONG_HIGH = 'p';
  public static final char WEAK_LOW = 'l';
  public static final char STRONG_LOW = ';';

  private final static int INDENT = 10;
  // Not referenced anywhere in this class
  private final static String WORD_DELIMITERS = " \t\n.,?!();:";
  // The N in n-gram.  3 == use unigram, bigram and trigram.
  private final static int MAX_N = 3;
  // Note: all weights must sum to 1
  private final static float DEFAULT_UNIWEIGHT = 0.1f;
  private final static float DEFAULT_BIWEIGHT = 0.3f;
  private final static float DEFAULT_TRIWEIGHT =
    1.0f - DEFAULT_UNIWEIGHT - DEFAULT_BIWEIGHT;

  // Fields
  private HmmApplet m_applet;
  private Node m_root;
  private float m_weights[];
  private LinkedList m_lastMoves;
  private char m_prediction;
}

-------------------------------- Node.java -----------------------------

import java.util.ArrayList;

/**
 * Copyright (c) 2004 Leo Lee. All Rights Reserved.
 * A node in the tree structure used to represent the Hidden Markov Model.
 * Node contains the current move and the count of transitions to this
 * node from its parent, from which the statistical probability
 * P(this node | parent node) is computed.
 * @author Leo Lee
 * @version 1.0 11/10/2004
 */
public class Node
{

  /**
   * Constructor.
   * @param moveCode Code for move of this node.
   * 'O' = weak high attack
   * 'P' = strong high attack
   * 'L' = weak low attack
   * ';' = strong low attack
   * @param parent Parent node.
   */
  public Node(char moveCode, Node parent)
  {
    m_moveCode = Character.toLowerCase(moveCode);
    m_count = 0;
    m_childrenCount = 0;
    m_children = new ArrayList();
    m_parent = parent;
  }

  //---------------------------- overrides -------------------------------//

  /**
   * String representation of the node, which is of the form: moveCode(count)
   * @return The string representation.
   */
  public String toString()
  {
    return m_moveCode + "( " + m_count + " )";
  }

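  /**
   * Equal if the given object is a Node, contains the same move code,
   * and has the same parent node.
   * @param other Other object to compare this one to.
   * @return true if this and other are equal, otherwise false.
   */
  public boolean equals(Object other)
  {
    if ( other instanceof Node )
    {
      Node otherNode = (Node)other;
      return m_moveCode == otherNode.m_moveCode
        && m_parent == otherNode.m_parent;
    }
    return false;
  }

  /**
   * Hash code consistent with equals(): derived from the move code
   * and the identity of the parent node.
   * @return The hash code.
   */
  public int hashCode()
  {
    return 31 * m_moveCode + System.identityHashCode(m_parent);
  }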

  //---------------------------- mutators --------------------------------//

  /**
   * Increment the count for this node by 1. Also will increment
   * the parent's sum of children's count by 1, if there is a parent.
   */
  public void incCount()
  {
    m_count++;
    if ( m_parent != null ) m_parent.m_childrenCount++;
  }

  /**
   * Adds the given node as a child to this one.
   * @param child Child node to add.
   */
  public void addChild(Node child)
  {
    m_children.add(child);
  }

  //---------------------------- accessors -------------------------------//

  /**
   * Gets the move code for this node.
   * @return The move code for this node.
   */
  public char getMoveCode()
  {
    return m_moveCode;
  }

  /**
   * Gets the count of going from the parent to this child.
   * @return The count.
   */
  public int getCount()
  {
    return m_count;
  }

  /**
   * Gets the sum of all my children's count.
   * @return Total children's count.
   */
  public long getChildrenCount()
  {
    return m_childrenCount;
  }

  /**
   * Gets children nodes of this one.
   * @return Children nodes.
   */
  public ArrayList getChildren()
  {
    return m_children;
  }

  // Fields
  private char m_moveCode;
  private int m_count;
  private long m_childrenCount;
  private ArrayList m_children;
  private Node m_parent;
}