function distance_J_B = bhattacharyya_distance_GMM( hmm1, hmm2, mode, varargin )
% function distance_J_B = bhattacharyya_distance_GMM( hmm1, hmm2, mode, [part = 'all_parts'] )
%
% Bhattacharyya distance between two single-state, diagonal-covariance
% GMM-HMMs, approximated by the distance between one Gaussian of each model.
%
% hmm1, hmm2 : HMM structs; fields read are num_of_states and
%              state(1).mix_weight, state(1).mean_vector and
%              state(1).variance_vector (indexed by row = mixture component).
%              Only num_of_states == 3 is accepted (treated as a single-state
%              HMM, per the error message).
% mode       : 'use_max_mixture' is the only supported mode: the mixture
%              component with the largest weight is taken from each model.
%              Any other value raises an error.
% part       : 'mean_part_only', 'cov_part_only' or 'all_parts' (default).
%              Any value other than the first two returns the sum of both parts.
%
% With d = mean1 - mean2 and C = (C1 + C2) / 2 (all covariances diagonal):
%   mean part = 1/8 * d * inv(C) * d'
%   cov part  = 1/2 * log( det(C) / sqrt( det(C1) * det(C2) ) )
%
% CVS_Version_String = '$Id: bhattacharyya_distance_GMM.m,v 1.2 2004/04/07 18:09:51 tuerk Exp $';
% CVS_Name_String = '$Name: rel-1-4-01 $';

% ###########################################################
%
% This file is part of the matlab scripts of the MASV System.
% MASV = Munich Automatic Speaker Verification
%
% Copyright 2002-2003, Ulrich Türk
% Institute of Phonetics and Speech Communication
% University of Munich
% tuerk@phonetik.uni-muenchen.de
%
%
%   MASV is free software; you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation; either version 2 of the License, or
%   (at your option) any later version.
%
%   MASV is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%
%   You should have received a copy of the GNU General Public License
%   along with MASV; if not, write to the Free Software
%   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
%
% ###########################################################



	if ( (hmm1.num_of_states ~= 3) || (hmm2.num_of_states ~= 3) ),
		error('can only handle single state hmms!');
	end

	% variance_vector is expected to hold one row of diagonal variances per
	% mixture component; anything with more dimensions is rejected.
	if ( (ndims(hmm1.state(1).variance_vector) ~= 2) || (ndims(hmm2.state(1).variance_vector) ~= 2) ),
		error('can only handle diagonal covariance hmms!');
	end

	if (nargin > 3),
		part = varargin{1};
	else
		part = 'all_parts';
	end

	% Fail loudly instead of leaving the output argument unassigned.
	if (~strcmp(mode, 'use_max_mixture')),
		error('bhattacharyya_distance_GMM: unsupported mode, only ''use_max_mixture'' is implemented');
	end

	% Select the mixture component with the largest weight in each model.
	[unused_weight_1, max_mix_index_1] = max(hmm1.state(1).mix_weight);
	[unused_weight_2, max_mix_index_2] = max(hmm2.state(1).mix_weight);

	% Row indexing with ':' always yields 1xD row vectors.
	mean1 = hmm1.state(1).mean_vector(max_mix_index_1,:);
	mean2 = hmm2.state(1).mean_vector(max_mix_index_2,:);
	var1  = hmm1.state(1).variance_vector(max_mix_index_1,:);
	var2  = hmm2.state(1).variance_vector(max_mix_index_2,:);

	% Diagonal of (C1 + C2) / 2.
	var_avg = (var1 + var2) ./ 2;

	% For diagonal C:  d * inv(C) * d'  ==  sum( d.^2 ./ diag(C) ).
	mean_diff = mean1 - mean2;
	mean_part = (1 / 8) * sum( (mean_diff .^ 2) ./ var_avg );

	% log( det(C) / sqrt(det(C1) * det(C2)) ) evaluated per dimension in the
	% log domain: the determinant of a high-dimensional diagonal matrix with
	% small (or large) variances under/overflows to 0 (or Inf), which would
	% turn the naive product into NaN or +/-Inf.
	cov_part = (1 / 2) * sum( log(var_avg) - (log(var1) + log(var2)) ./ 2 );

	if (strcmp(part, 'mean_part_only')),
		distance_J_B = mean_part;
	elseif (strcmp(part, 'cov_part_only')),
		distance_J_B = cov_part;
	else
		distance_J_B = mean_part + cov_part;
	end

return;