function [alpha, beta, gamma, loglik, xi, gamma2] = fwdback(varargin)
% FWDBACK Compute the posterior probs. in an HMM using the forwards backwards algo.
%
% [alpha, beta, gamma, loglik, xi, gamma2] = fwdback('param_name1', args1, ...)
%
% Notation:
% Y(t) = observation, Q(t) = hidden state, M(t) = mixture variable (for MOG outputs)
% A(t) = discrete input (action) (for POMDP models)
%
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'init_state_distrib' - init_state_distrib(i) = Pr(Q(1) = i)  []
%   ('init_state' is also accepted as an alias, for backwards compatibility.)
% 'transmat' - transmat(i,j) = Pr(Q(t) = j | Q(t-1)=i)  []
% 'obslik' - obslik(i,t) = Pr(Y(t)| Q(t)=i)  []
%   (Compute obslik using eval_pdf_xxx on your data sequence first.)
%
% For HMMs with MOG outputs:
% 'obslik2' - obslik2(i,j,t) = Pr(Y(t)| Q(t)=i,M(t)=j)  []
% 'mixmat' - mixmat(i,j) = Pr(M(t) = j | Q(t)=i)  []
%
% For HMMs with discrete inputs:
% 'transmat' - transmat(i,j,a) = Pr(Q(t) = j | Q(t-1)=i, A(t-1)=a)  []
% 'act' - act(t) = action performed at step t  [ones(1,T)]
%
% Optional arguments:
% 'fwd_only' - if 1, only do a forwards pass and set beta=[], gamma2=[]  [0]
% 'scaled' - if 1,  normalize alphas and betas to prevent underflow [1]
% 'maximize' - if 1, do max-product (used by Viterbi) instead of sum-product [0]
%
% OUTPUTS:
% alpha(i,t) = p(Q(t)=i | y(1:t)) (or p(Q(t)=i, y(1:t)) if scaled=0)
% beta(i,t) = p(y(t+1:T) | Q(t)=i) if scaled=0; if scaled=1, the same quantity
%             normalised so that sum_i beta(i,t) = 1 for t < T. In both cases beta(:,T) = 1.
% gamma(i,t) = p(Q(t)=i | y(1:T))
% loglik = log p(y(1:T))
% xi(i,j,t-1)  = p(Q(t-1)=i, Q(t)=j | y(1:T))
% gamma2(j,k,t) = p(Q(t)=j, M(t)=k | y(1:T)) (only for MOG outputs)
%
% If fwd_only = 1, these become
% alpha(i,t) = p(Q(t)=i | y(1:t))
% beta = []
% gamma = alpha, i.e. gamma(i,t) = p(Q(t)=i | y(1:t)) when scaled=1
% xi(i,j,t-1)  = p(Q(t-1)=i, Q(t)=j | y(1:t))
% gamma2 = []
%
% If maximize = 1, every sum over hidden states is replaced by a max. alpha, beta,
% gamma and xi then hold (normalised) max-marginals, and
% loglik = log sum_j max_{Q(1:T-1)} p(Q(1:T-1), Q(T)=j, y(1:T)).
%
% Note: we only compute xi if it is requested as a return argument, since it can be very large.
% Similarly, we only compute gamma2 on request (and if using MOG outputs).
%
% Example:
% [alpha, beta, gamma, loglik] = fwdback('init_state_distrib', pi, 'transmat', A, 'obslik', B);

[init_state_distrib, transmat, obslik, obslik2, mixmat, fwd_only, scaled, act, maximize, init_state] = ...
    process_options(varargin, ...
        'init_state_distrib', [], 'transmat', [], 'obslik', [], 'obslik2', [], 'mixmat', [], ...
        'fwd_only', 0, 'scaled', 1, 'act', [], 'maximize', 0, 'init_state', []);

% Older callers used the option name 'init_state'; honour it if the documented name is absent.
if isempty(init_state_distrib)
  init_state_distrib = init_state;
end

% xi and gamma2 can be very large, so only build them when the caller asks for them.
% Both flags must always be defined, because they are tested below regardless of nargout.
compute_xi = (nargout >= 5);
compute_gamma2 = (nargout >= 6) && ~isempty(obslik2);

[Q T] = size(obslik);

if isempty(act)
  act = ones(1,T);
end

% scale(t) = Pr(O(t) | O(1:t-1)) = 1/c(t) as defined by Rabiner (1989).
% Hence prod_t scale(t) = Pr(O(1)) Pr(O(2)|O(1)) Pr(O(3) | O(1:2)) ... = Pr(O(1), ... ,O(T))
% or log P = sum_t log scale(t).
% Rabiner suggests multiplying beta(t) by scale(t), but we can instead
% normalise beta(t) - the constants will cancel when we compute gamma.
scale = ones(1,T);
prior = init_state_distrib(:);

alpha = zeros(Q,T);
gamma = zeros(Q,T);
if compute_xi
  xi = zeros(Q,Q,T-1);
else
  xi = [];
end

%%%%%%%%% Forwards %%%%%%%%%%

alpha(:,1) = prior .* obslik(:,1);
if scaled
  [alpha(:,1), scale(1)] = normalise(alpha(:,1));
end
for t=2:T
  trans = transmat(:,:,act(t-1));   % trans(i,j) = Pr(Q(t)=j | Q(t-1)=i, A(t-1))
  if maximize
    % pred(j) = max_i trans(i,j) * alpha(i,t-1)
    pred = max(trans .* repmat(alpha(:,t-1), 1, Q), [], 1)';
  else
    % pred(j) = sum_i trans(i,j) * alpha(i,t-1)
    pred = trans' * alpha(:,t-1);
  end
  alpha(:,t) = pred .* obslik(:,t);
  if scaled
    [alpha(:,t), scale(t)] = normalise(alpha(:,t));
  end
  if compute_xi && fwd_only  % useful for online EM
    % xi(i,j,t-1) is proportional to alpha(i,t-1) * trans(i,j) * obslik(j,t).
    % This needs the untransposed matrix so that entry (i,j) pairs with alpha(i) and obslik(j).
    xi(:,:,t-1) = normalise((alpha(:,t-1) * obslik(:,t)') .* trans);
  end
end
if scaled
  loglik = sum(log(scale));
else
  loglik = log(sum(alpha(:,T)));
end

if fwd_only
  gamma = alpha;
  beta = [];
  gamma2 = [];
  return;
end

%%%%%%%%% Backwards %%%%%%%%%%

beta = zeros(Q,T);
if compute_gamma2
  M = size(mixmat, 2);
  gamma2 = zeros(Q,M,T);
else
  gamma2 = [];
end

beta(:,T) = ones(Q,1);
gamma(:,T) = normalise(alpha(:,T) .* beta(:,T));
if compute_gamma2
  gamma2(:,:,T) = mixture_posterior(obslik2(:,:,T), mixmat, gamma(:,T), obslik(:,T));
end
for t=T-1:-1:1
  b = beta(:,t+1) .* obslik(:,t+1);
  trans = transmat(:,:,act(t));   % trans(i,j) = Pr(Q(t+1)=j | Q(t)=i, A(t))
  if maximize
    % beta(i,t) = max_j trans(i,j) * b(j)
    beta(:,t) = max(trans .* repmat(b', Q, 1), [], 2);
  else
    % beta(i,t) = sum_j trans(i,j) * b(j)
    beta(:,t) = trans * b;
  end
  if scaled
    beta(:,t) = normalise(beta(:,t));
  end
  gamma(:,t) = normalise(alpha(:,t) .* beta(:,t));
  if compute_xi
    % xi(i,j,t) is proportional to alpha(i,t) * trans(i,j) * obslik(j,t+1) * beta(j,t+1)
    xi(:,:,t) = normalise(trans .* (alpha(:,t) * b'));
  end
  if compute_gamma2
    gamma2(:,:,t) = mixture_posterior(obslik2(:,:,t), mixmat, gamma(:,t), obslik(:,t));
  end
end


function g2 = mixture_posterior(obslik2_t, mixmat, gamma_t, obslik_t)
% MIXTURE_POSTERIOR Compute gamma2(:,:,t) = P(Q(t), M(t) | y(1:T)) for a single time step t.
%
% Derivation. Let zt = y(1:t-1,t+1:T) be all observations except y(t).
% gamma2(Q,M,t) = P(Qt,Mt|yt,zt) = P(yt|Qt,Mt,zt) P(Qt,Mt|zt) / P(yt|zt)
%                = P(yt|Qt,Mt) P(Mt|Qt) P(Qt|zt) / P(yt|zt)
% Now gamma(Q,t) = P(Qt|yt,zt) = P(yt|Qt) P(Qt|zt) / P(yt|zt)
% hence
% P(Qt,Mt|yt,zt) = P(yt|Qt,Mt) P(Mt|Qt) [P(Qt|yt,zt) P(yt|zt) / P(yt|Qt)] / P(yt|zt)
%                = P(yt|Qt,Mt) P(Mt|Qt) P(Qt|yt,zt) / P(yt|Qt)
% i.e. obslik2(:,:,t) .* mixmat .* gamma(:,t) ./ obslik(:,t), applied row-wise.
%
% NOTE(review): this assumes obslik(:,t) = sum_k mixmat(:,k) .* obslik2(:,k,t), as for a
% mixture output model; the final normalise makes the result sum to 1 in any case.
M = size(mixmat, 2);
% States whose likelihood is exactly 0 also have gamma = 0. Dividing by 1 instead of 0
% keeps those rows at 0 rather than producing 0/0 = NaN.
denom = obslik_t + (obslik_t == 0);
g2 = normalise(obslik2_t .* mixmat .* repmat(gamma_t ./ denom, [1 M]));
