University of Connecticut

The TreeUpdater class

The TreeUpdater class is responsible for updating a tree. Updating a tree involves proposing a change to the tree topology and edge lengths and then deciding whether to accept the proposed modification.

Create a new file tree_updater.hpp and replace the default contents with the following class declaration.

#pragma once

#include "tree.hpp"
#include "lot.hpp"
#include "likelihood.hpp"
#include "tree_manip.hpp"

namespace strom
    {

    template<class T> class Chain;

    // Metropolis-Hastings updater for a tree. Each call to update() proposes a
    // change to the tree topology and edge lengths (proposeNewState), then either
    // keeps the proposal or restores the previous state (revert).
    template <class T>
    class TreeUpdater
        {
        friend class Chain<T>;

        public:

            // Log posterior kernel: first = log likelihood, second = log prior.
            typedef std::pair<double, double> kernel_t;

            TreeUpdater();
            ~TreeUpdater();

            // Supplies the object used to compute log likelihoods.
            void setLikelihood(typename Likelihood<T>::SharedPtr likelihood)
                {
                _likelihood = likelihood;
                }

            // Supplies the manipulator that owns the tree being updated.
            void setTreeManip(typename TreeManip<T>::SharedPtr treemanip)
                {
                _tree_manipulator = treemanip;
                }

            // Supplies the pseudorandom number generator.
            void setLot(Lot::SharedPtr lot)
                {
                _lot = lot;
                }

            // Sets the tuning parameter that scales the boldness of proposals.
            void setLambda(double lambda)
                {
                _lambda = lambda;
                }

            // Turns automatic tuning of _lambda on or off.
            void setTuning(bool on);

            // Sets the acceptance rate that tune() steers toward.
            void setTargetAcceptanceRate(double target)
                {
                _target_acceptance = target;
                }

            // Restores the just-constructed state.
            void clear();

            // Computes (log likelihood, log prior) for the current tree.
            kernel_t calcLogPosteriorKernel() const;

            // Performs one propose/accept-or-reject step.
            kernel_t update();

        private:

            double calcLogTopologyPrior() const;
            double calcEdgeLengthPrior() const;

            void proposeNewState();
            void revert();
            void reset();
            void tune(bool accepted);

            // Collaborators supplied through the setters above.
            Lot::SharedPtr _lot;
            typename Likelihood<T>::SharedPtr _likelihood;
            typename TreeManip<T>::SharedPtr _tree_manipulator;

            // Lengths of the three focal-path edges before the proposal.
            double _orig_edgelen_top;
            double _orig_edgelen_middle;
            double _orig_edgelen_bottom;

            // Lengths of the three focal-path edges after rescaling.
            double _new_edgelen_top;
            double _new_edgelen_middle;
            double _new_edgelen_bottom;

            // Per-proposal bookkeeping (cleared by reset()) and tuning state.
            double _lambda;
            double _log_hastings_ratio;
            unsigned _case;             // 1..8 after proposeNewState(), 0 otherwise
            bool _topology_changed;
            T * _x;
            T * _y;
            T * _a;
            T * _b;
            T * _c;
            T * _d;
            double _smallest_edge_length;
            double _target_acceptance;
            unsigned _naccepts;
            unsigned _nattempts;
            bool _tuning;
            kernel_t _prev_logkernel;   // kernel of the currently accepted state

        public:
            typedef boost::shared_ptr< TreeUpdater<T> > SharedPtr;
        };
 
// member function bodies go here
}

Constructor and destructor

The constructor delegates all its initialization work to the clear() function and the destructor does nothing.

// Constructs an updater in its default state. Every data member is given its
// initial value by clear(), so nothing else needs to be done here.
template <class T>
inline TreeUpdater<T>::TreeUpdater()
    {
    this->clear();
    }

// Destructor: intentionally empty, there is no cleanup code to run here.
template <class T>
inline TreeUpdater<T>::~TreeUpdater()
    {
    // nothing to do
    }

The clear() and reset() functions

Together, clear() and reset() initialize everything. The clear() function can be used to return the TreeUpdater class to its just-constructed state, while reset() is used to get the TreeUpdater ready for the next call to update. Note that reset() avoids wiping out quantities that are needed for subsequent updates, such as _prev_logkernel, which stores the current log likelihood and log prior values.

// Returns the updater to its just-constructed state: default tuning settings,
// zeroed acceptance counters, an "unset" previous kernel, and (via reset())
// cleared per-proposal bookkeeping.
template <class T>
inline void TreeUpdater<T>::clear()
    {
    // Tuning configuration
    _tuning               = true;
    _lambda               = 0.2;
    _target_acceptance    = 0.15;

    // Acceptance statistics
    _naccepts             = 0;
    _nattempts            = 0;

    // Lower bound enforced on proposed edge lengths
    _smallest_edge_length = 1.e-12;

    // (0.0, 0.0) marks the kernel as not yet computed; update() checks for this
    _prev_logkernel       = kernel_t(0.0, 0.0);

    reset();
    }

// Clears the bookkeeping left over from one proposal so the updater is ready
// for the next call to update(). Tuning state, acceptance counters and
// _prev_logkernel are deliberately left untouched.
template <class T>
inline void TreeUpdater<T>::reset()
    {
    // No proposal is pending
    _case             = 0;
    _topology_changed = false;
    _log_hastings_ratio = 0.0;

    // Saved and proposed lengths of the three focal-path edges
    _orig_edgelen_top = _orig_edgelen_middle = _orig_edgelen_bottom = 0.0;
    _new_edgelen_top  = _new_edgelen_middle  = _new_edgelen_bottom  = 0.0;

    // Nodes involved in the last proposal
    _x = _y = _a = _b = _c = _d = 0;
    }

Functions related to tuning

The tune function adjusts the tuning parameter of this Metropolis-Hastings updater. Calling this function after each update during the burn-in period, passing in true if the proposal was accepted and false otherwise, will gradually adapt the sampler to get as close as possible to the target (user-supplied) acceptance rate. Because the tree topology changes are limited in their boldness (NNI swaps are the only possible topology changes), it is not possible to adjust this particular proposal to any desired target value, but this function serves to get as close as possible to the target.

The setTuning function is used to turn tuning on or off. Call setTuning(true) to tune the updater. After a suitable burn-in period, turn off tuning by calling setTuning(false) in order to draw valid posterior samples in future updates.

The tuning algorithm used here is described in this paper:

Prokaj, Vilmos. 2009. Proposal selection for MCMC simulation. pp. 61-65 in Sakalauskas L., C. Skiadas and E. K. Zavadskas (eds.), Applied stochastic models and data analysis, the XIII International conference (ASMDA-2009), June 30-July 3, 2009, Vilnius, Lithuania.

// Turns automatic tuning of _lambda on (true) or off (false).
//
// When tuning is switched off the acceptance statistics are restarted so that
// they describe only the post-tuning phase. Both counters are cleared
// together: update() reports the acceptance percentage as
// _naccepts/_nattempts, so clearing only _nattempts would leave stale accepts
// in the numerator and could yield a percentage above 100.
template <class T>
inline void TreeUpdater<T>::setTuning(bool on)
    {
    _tuning = on;
    if (!on)
        {
        _naccepts  = 0;
        _nattempts = 0;
        }
    }

// Adapts _lambda after a proposal when tuning is enabled; does nothing
// otherwise. An accepted proposal increases _lambda and a rejected one
// decreases it; the step size gamma_n shrinks as _nattempts grows.
//   accepted: true if the most recent proposal was accepted.
template <class T>
inline void TreeUpdater<T>::tune(bool accepted)
    {
    if (!_tuning)
        return;

    const double gamma_n = 10.0/(100.0 + static_cast<double>(_nattempts));
    const double scale = accepted
        ? 1.0 + gamma_n*(1.0 - _target_acceptance)/(2.0*_target_acceptance)
        : 1.0 - gamma_n*0.5;
    _lambda *= scale;
    }

The calcLogTopologyPrior function

This function calculates the tree topology prior. This function actually needs to be called only once because all tree topologies have the same prior probability under the discrete uniform prior assumed here; however, we will compute it each time unless later profiling indicates that this function requires a significant amount of computation.

// Returns the log of the topology prior, computed as minus the log of
// (2n-5)!/(2^(n-3) (n-3)!), where n is the number of leaves (plus one if the
// tree is rooted).
template <class T>
inline double TreeUpdater<T>::calcLogTopologyPrior() const
    {
    typename Tree<T>::SharedPtr tree = _tree_manipulator->getTree();
    assert(tree);

    const double nleaves = tree->numLeaves();
    const bool rooted = tree->isRooted();
    const double n = rooted ? nleaves + 1.0 : nleaves;

    const double log_numer = lgamma(2.0*n - 5.0 + 1.0);    // log (2n-5)!
    const double log_pow2  = (n - 3.0)*log(2.0);           // log 2^(n-3)
    const double log_denom = lgamma(n - 3.0 + 1.0);        // log (n-3)!

    return -(log_numer - log_pow2 - log_denom);
    }

The calcEdgeLengthPrior() function

This function computes the edge length component of the prior. The prior distribution for each individual edge length is Exponential(10). This is not the best prior for edge lengths because it results in an implicit tree length prior that can have an unreasonably large mean and at the same time be quite informative (the sum of n independent Exponential(r) random variables is Gamma(n, 1/r), so the prior mean of the tree length grows with the number of edges in the tree). Thus, this prior should be viewed as a temporary placeholder.

// Returns the log of the joint edge length prior, where every edge length is
// given an independent Exponential prior with rate 10. The sum over edges of
// log(rate) - rate*length collapses to num_edges*log(rate) - rate*(tree length).
template <class T>
inline double TreeUpdater<T>::calcEdgeLengthPrior() const
    {
    typename Tree<T>::SharedPtr tree = _tree_manipulator->getTree();
    assert(tree);

    const double exp_rate = 10.0;
    const double nleaves = tree->numLeaves();
    const double tree_length = _tree_manipulator->calcTreeLength();

    // 2n-2 edges are counted for a rooted tree, 2n-3 for an unrooted one
    const double edges_to_subtract = tree->isRooted() ? 2.0 : 3.0;
    const double num_edges = 2.0*nleaves - edges_to_subtract;

    return num_edges*log(exp_rate) - exp_rate*tree_length;
    }

The calcLogPosteriorKernel function

This function simply calls calcLogLikelihood, calcLogTopologyPrior and calcEdgeLengthPrior to construct the log posterior kernel.

// Evaluates the log posterior kernel of the current tree.
// Returns a pair whose first element is the log likelihood and whose second
// element is the log prior (topology prior plus edge length prior).
template <class T>
inline typename TreeUpdater<T>::kernel_t TreeUpdater<T>::calcLogPosteriorKernel() const
    {
    kernel_t kernel;
    kernel.first  = _likelihood->calcLogLikelihood(_tree_manipulator->getTree());
    kernel.second = calcLogTopologyPrior() + calcEdgeLengthPrior();
    return kernel;
    }

The proposeNewState function

This is the largest and most complicated member function of the class. It is responsible for proposing a new tree state and doing the bookkeeping necessary to revert to the previous state if the proposed new state is not accepted. The proposal used here is the one proposed in the following papers:

Larget B., Simon D. 1999. Markov Chain Monte Carlo Algorithms for the Bayesian Analysis of Phylogenetic Trees. Molecular Biology and Evolution. 16:750–759.

Holder M.T., Lewis P.O., Swofford D.L., Larget B. 2005. Hastings ratio of the LOCAL proposal used in Bayesian phylogenetics. Systematic Biology. 54:961–965.

// Proposes a new tree state and records what is needed to undo it.
//
// Steps, in order:
//   1. Pick an internal edge (node _x and its parent _y) and identify the four
//      neighboring nodes _a, _b (children of _x), _d (sibling of _x), _c (parent of _y).
//   2. Pick a three-edge focal path: (a or b) - x - (y's own edge or d's edge),
//      saving the original lengths in _orig_edgelen_top/middle/bottom.
//   3. Rescale all three edges by a common factor m = exp(_lambda*(u - 0.5)) and
//      set _log_hastings_ratio = 3*log(m).
//   4. Choose an attachment point along the rescaled path and which node to move;
//      if the moved node crosses the neighboring junction, an NNI swap is
//      performed and _topology_changed is set.
//
// On exit _case is in 1..8 and identifies which nodes carry the modified edge
// lengths; revert() relies on _case, _topology_changed and the _orig_edgelen_*
// values to restore the previous state.
//
// NOTE(review): _lot->uniform() is presumably a Uniform(0,1) draw, and _x is
// presumably guaranteed to have exactly two children and a non-null parent;
// confirm against Lot and TreeManip::randomInternalEdge.
template <class T>
inline void TreeUpdater<T>::proposeNewState()
    {
    _case = 0;
    _topology_changed = false;
    
    // Choose random internal node x and let a and b equal its left and right children, d its sibling, y its parent, and
    // c y's parent. Thus, x and y are the vertices at the end of the chosen internal edge, a and b are attached to x,
    // and c and d are attached to y:
    //
    //  a     b
    //   \   /       
    //    \ /        
    //     x     d
    //      \   /    
    //       \ /     
    //        y
    //        |      
    //        |      
    //        c
    //
    _x = _tree_manipulator->randomInternalEdge(_lot->uniform());

    _a = _x->getLeftChild();
    _b = _a->getRightSib();
    _y = _x->getParent();
    // _c is recorded here but not otherwise used in this function
    _c = _y->getParent();
    _d = 0;
    // d is whichever child of y is not x
    if (_x == _y->getLeftChild())
        _d = _x->getRightSib();
    else
        _d = _y->getLeftChild();
        
    // Choose focal 3-edge segment to shrink or grow
    // Top edge of the focal path: a's edge or b's edge, each with probability 0.5
    bool a_on_path = (_lot->uniform() < 0.5);
    if (a_on_path)
        _orig_edgelen_top = _a->getEdgeLength();
    else
        _orig_edgelen_top = _b->getEdgeLength();

    // Middle edge of the focal path is always x's edge
    _orig_edgelen_middle = _x->getEdgeLength();

    // Bottom edge of the focal path: y's edge (toward c) or d's edge, each with probability 0.5
    bool c_on_path = (_lot->uniform() < 0.5);
    if (c_on_path)
        _orig_edgelen_bottom = _y->getEdgeLength();
    else
        _orig_edgelen_bottom = _d->getEdgeLength();
        
    // Common multiplier applied to all three focal edges; three edges are scaled by m,
    // hence the factor 3 in the log Hastings ratio
    double m = exp(_lambda*(_lot->uniform() - 0.5));
    _log_hastings_ratio = 3.0*log(m);
    
    _new_edgelen_top    = m*_orig_edgelen_top;
    _new_edgelen_middle = m*_orig_edgelen_middle;
    _new_edgelen_bottom = m*_orig_edgelen_bottom;

    // Decide where along focal path (starting from top) to place moved node
    double new_focal_path_length = _new_edgelen_top + _new_edgelen_middle + _new_edgelen_bottom;
    double u = _lot->uniform();
    double new_attachment_point = u*new_focal_path_length;
    // Keep the attachment point at least _smallest_edge_length away from either end of the focal path
    if (new_attachment_point <= _smallest_edge_length)
        new_attachment_point = _smallest_edge_length;
    else if (new_focal_path_length - new_attachment_point <= _smallest_edge_length)
        new_attachment_point = new_focal_path_length - _smallest_edge_length;
    
    // Decide which node to move, and whether the move involves a topology change
    // In the diagrams below, parenthesized nodes appear to mark the focal path and the
    // asterisk the node being moved. In each case the first branch handles an attachment
    // point that crosses the other junction (NNI swap needed) and the second branch
    // handles one that does not (only edge lengths change).
    u = _lot->uniform();
    if (a_on_path && c_on_path)
        {
        if (u < 0.5)
            {
            _case = 1;
            
            // (a)    b*     
            //   \   /       
            //    \ /        
            //    (x)    d   
            //      \   /    
            //       \ /     
            //       (y)     
            //        |      
            //        |      
            //       (c)
            
            if (new_attachment_point > _new_edgelen_top + _new_edgelen_middle)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_b, _d);
                _a->setEdgeLength(_new_edgelen_top + _new_edgelen_middle);
                _x->setEdgeLength(new_attachment_point - _a->getEdgeLength());
                _y->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            else
                {
                _topology_changed = false;
                _a->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top + _new_edgelen_middle - new_attachment_point);
                _y->setEdgeLength(_new_edgelen_bottom);
                }
            }
        else
            {
            _case = 2;
            
            // (a)    b     
            //   \   /      
            //    \ /       
            //    (x)    d* 
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //       (c)
            
            if (new_attachment_point < _new_edgelen_top)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_b, _d);
                _a->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top - new_attachment_point);
                _y->setEdgeLength(_new_edgelen_middle + _new_edgelen_bottom);
                }
            else
                {
                _topology_changed = false;
                _a->setEdgeLength(_new_edgelen_top);
                _x->setEdgeLength(new_attachment_point - _new_edgelen_top);
                _y->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            }
        }
    else if (!a_on_path && c_on_path)
        {
        if (u < 0.5)
            {
            _case = 3;
            
            //  a*   (b)    
            //   \   /      
            //    \ /       
            //    (x)    d  
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //       (c)
            
            if (new_attachment_point > _new_edgelen_top + _new_edgelen_middle)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_a, _d);
                _b->setEdgeLength(_new_edgelen_top + _new_edgelen_middle);
                _x->setEdgeLength(new_attachment_point - _b->getEdgeLength());
                _y->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            else
                {
                _topology_changed = false;
                _b->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top + _new_edgelen_middle - new_attachment_point);
                _y->setEdgeLength(_new_edgelen_bottom);
                }
            }
        else
            {
            _case = 4;

            //  a    (b)
            //   \   /      
            //    \ /       
            //    (x)    d*
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //       (c)

            if (new_attachment_point < _new_edgelen_top)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_a, _d);
                _b->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top - new_attachment_point);
                _y->setEdgeLength(_new_edgelen_middle + _new_edgelen_bottom);
                }
            else
                {
                _topology_changed = false;
                _b->setEdgeLength(_new_edgelen_top);
                _x->setEdgeLength(new_attachment_point - _new_edgelen_top);
                _y->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            }
        }
    else if (a_on_path && !c_on_path)
        {
        if (u < 0.5)
            {
            _case = 5;

            // (a)    b*
            //   \   /      
            //    \ /       
            //    (x)   (d)
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //        c

            if (new_attachment_point > _new_edgelen_top + _new_edgelen_middle)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_a, _d);
                _a->setEdgeLength(_new_edgelen_top + _new_edgelen_middle);
                _x->setEdgeLength(new_attachment_point - _a->getEdgeLength());
                _d->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            else
                {
                _topology_changed = false;
                _a->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top + _new_edgelen_middle - new_attachment_point);
                _d->setEdgeLength(_new_edgelen_bottom);
                }
            }
        else
            {
            _case = 6;

            // (a)    b
            //   \   /      
            //    \ /       
            //    (x)   (d)
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //        c*

            if (new_attachment_point < _new_edgelen_top)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_a, _d);
                _d->setEdgeLength(_new_edgelen_bottom + _new_edgelen_middle);
                _x->setEdgeLength(_new_edgelen_top - new_attachment_point);
                _a->setEdgeLength(new_attachment_point);
                }
            else
                {
                _topology_changed = false;
                _a->setEdgeLength(_new_edgelen_top);
                _x->setEdgeLength(new_attachment_point - _new_edgelen_top);
                _d->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            }
        }
    else
        {
        // Remaining combination: b and d are on the focal path
        if (u < 0.5)
            {
            _case = 7;

            //  a*   (b)
            //   \   /      
            //    \ /       
            //    (x)   (d)
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //        c

            if (new_attachment_point > _new_edgelen_top + _new_edgelen_middle)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_b, _d);
                _d->setEdgeLength(new_focal_path_length - new_attachment_point);
                _x->setEdgeLength(new_attachment_point - _new_edgelen_top - _new_edgelen_middle);
                _b->setEdgeLength(_new_edgelen_top + _new_edgelen_middle);
                }
            else
                {
                _topology_changed = false;
                _b->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top + _new_edgelen_middle - new_attachment_point);
                _d->setEdgeLength(_new_edgelen_bottom);
                }
            }
        else
            {
            _case = 8;

            //  a    (b)
            //   \   /      
            //    \ /       
            //    (x)   (d)
            //      \   /   
            //       \ /    
            //       (y)    
            //        |     
            //        |     
            //        c*

            if (new_attachment_point < _new_edgelen_top)
                {
                _topology_changed = true;
                _tree_manipulator->nniNodeSwap(_b, _d);
                _b->setEdgeLength(new_attachment_point);
                _x->setEdgeLength(_new_edgelen_top - new_attachment_point);
                _d->setEdgeLength(_new_edgelen_middle + _new_edgelen_bottom);
                }
            else
                {
                _topology_changed = false;
                _b->setEdgeLength(_new_edgelen_top);
                _x->setEdgeLength(new_attachment_point - _new_edgelen_top);
                _d->setEdgeLength(new_focal_path_length - new_attachment_point);
                }
            }
        }

    // Sanity checks (debug builds only): no edge touched by the proposal may be shorter than the minimum
    assert(_a->getEdgeLength() >= _smallest_edge_length);
    assert(_b->getEdgeLength() >= _smallest_edge_length);
    assert(_d->getEdgeLength() >= _smallest_edge_length);
    assert(_x->getEdgeLength() >= _smallest_edge_length);
    assert(_y->getEdgeLength() >= _smallest_edge_length);
    }

The revert function

This function is called if the proposed new state is rejected. It simply returns the tree to its state prior to the last call of the proposeNewState function.

// Undoes the most recent proposal made by proposeNewState().
//
// _case identifies which two nodes (besides _x) carry the top and bottom
// focal-path edges and which node was exchanged with _d if an NNI swap was
// performed. The swap, if any, is undone first; then the three saved edge
// lengths are written back.
template <class T>
inline void TreeUpdater<T>::revert()
    {
    assert(_case > 0 && _case < 9);

    T * top_node     = 0;   // node owning the top focal edge
    T * bottom_node  = 0;   // node owning the bottom focal edge
    T * swap_partner = 0;   // node that was exchanged with _d

    switch (_case)
        {
        case 1:
        case 2:
            top_node     = _a;
            bottom_node  = _y;
            swap_partner = _b;
            break;
        case 3:
        case 4:
            top_node     = _b;
            bottom_node  = _y;
            swap_partner = _a;
            break;
        case 5:
        case 6:
            top_node     = _a;
            bottom_node  = _d;
            swap_partner = _a;
            break;
        case 7:
        case 8:
            top_node     = _b;
            bottom_node  = _d;
            swap_partner = _b;
            break;
        default:
            // no recognized proposal is pending, so there is nothing to undo
            return;
        }

    if (_topology_changed)
        _tree_manipulator->nniNodeSwap(_d, swap_partner);

    top_node->setEdgeLength(_orig_edgelen_top);
    _x->setEdgeLength(_orig_edgelen_middle);
    bottom_node->setEdgeLength(_orig_edgelen_bottom);
    }

The update function

The final function orchestrates the proposal by calling proposeNewState and then deciding whether to accept or reject it. The "temporary debugging output" will be removed soon, but for now this provides our only insight into how the MCMC is going.

// Performs one Metropolis-Hastings step on the tree.
//
// A new state is proposed, its log posterior kernel is compared with that of
// the current state (plus the log Hastings ratio), and the proposal is either
// kept or undone with revert(). The tuning parameter is then adapted via
// tune() and the per-proposal bookkeeping is cleared with reset().
//
// Returns the (log likelihood, log prior) pair of the state the tree is in
// when this function returns: the proposed state's kernel if the proposal was
// accepted, otherwise the unchanged kernel of the previous state. (Returning
// the proposed kernel after a rejection would report values for a tree that
// has already been reverted.)
template <class T>
inline typename TreeUpdater<T>::kernel_t TreeUpdater<T>::update()
    {
    // clear() sets the stored kernel to (0.0, 0.0); a log likelihood of exactly
    // 0.0 is treated as "not yet computed"
    if (_prev_logkernel.first == 0.0)
        _prev_logkernel = calcLogPosteriorKernel();

    proposeNewState();

    // Log acceptance ratio = log Hastings ratio + log likelihood ratio + log prior ratio
    kernel_t logkernel = calcLogPosteriorKernel();
    double log_likelihood_diff = logkernel.first - _prev_logkernel.first;
    double log_prior_diff      = logkernel.second  - _prev_logkernel.second;
    double log_diff            = _log_hastings_ratio + log_likelihood_diff + log_prior_diff;

    // Accept unless the log of the random draw exceeds the log acceptance ratio
    bool accept = true;
    double logu = _lot->logUniform();
    if (logu > log_diff)
        accept = false;

    if (accept)
        {
        _naccepts++;
        _prev_logkernel = logkernel;
        }
    else
        {
        revert();
        }
        
    // tune() sees the attempt count from before this update
    tune(accept);
    _nattempts++;

    // temporary debugging output
    double pctaccept = 100.0*_naccepts/_nattempts;
    double TL = _tree_manipulator->calcTreeLength();
    std::cout << boost::str(boost::format("%12s %12s %12.1f %12.5f %12.5f %12.5f %12.5f %12.5f %12.5f") % (_topology_changed ? "*" : " ") % (accept ? "accept" : "reject") % pctaccept % _lambda % logu % log_diff % _prev_logkernel.first % _prev_logkernel.second % TL) << std::endl;
    reset();

    // _prev_logkernel always describes the tree as it now stands
    return _prev_logkernel;
    }