function distance_J_D = divergence_distance_GMM( hmm1, hmm2, mode )
% DIVERGENCE_DISTANCE_GMM  Symmetric (J) divergence between two GMM-state HMMs.
%
%   d = divergence_distance_GMM(hmm1, hmm2, mode)
%
%   Inputs
%     hmm1, hmm2 : structs with the fields read below:
%                    .num_of_states
%                    .state(1).mix_weight       mixture weights
%                    .state(1).mean_vector      indexed as (mixture, :)
%                    .state(1).variance_vector  indexed as (mixture, :)
%     mode       : approximation to use. Only 'use_max_mixture' is
%                  implemented; it is also the default when omitted.
%
%   Output
%     distance_J_D : scalar,
%         0.5 * (m1 - m2) * (inv(C1) + inv(C2)) * (m1 - m2)'
%       + 0.5 * trace(inv(C2)*C1 + inv(C1)*C2 - 2*I)
%       where m and C = diag(variance) belong to the highest-weight
%       mixture component of state(1) of each model.
%
%   Errors are raised when the state count is not 3, when the variance
%   array is not 2-D, when the feature dimensions of the two selected
%   components differ, or when mode is not supported.

	% NOTE(review): the check requires num_of_states == 3 while the message
	% says "single state" -- presumably the count includes non-emitting
	% entry/exit states; confirm against the code that builds these structs.
	if ( (hmm1.num_of_states ~= 3) || (hmm2.num_of_states ~= 3) )
		error('can only handle single state hmms!');
	end

	% A variance array with more than two dimensions is rejected; only
	% per-mixture variance vectors (rows) are supported here.
	if ( (ndims(hmm1.state(1).variance_vector) ~= 2) || (ndims(hmm2.state(1).variance_vector) ~= 2) )
		error('can only handle diagonal covariance hmms!');
	end

	% The only implemented mode doubles as the default.
	if (nargin < 3)
		mode = 'use_max_mixture';
	end

	% Previously an unknown mode fell through and left the output unassigned.
	if (~(ischar(mode) && strcmp(mode, 'use_max_mixture')))
		error('divergence_distance_GMM: unsupported mode; only ''use_max_mixture'' is implemented.');
	end

	% Pick the mixture component with the largest weight in each model.
	[unused_w1, max_mix_index_1] = max(hmm1.state(1).mix_weight);
	[unused_w2, max_mix_index_2] = max(hmm2.state(1).mix_weight);

	% Work with column vectors regardless of how the rows are stored.
	mean1 = hmm1.state(1).mean_vector(max_mix_index_1, :);
	mean2 = hmm2.state(1).mean_vector(max_mix_index_2, :);
	var1  = hmm1.state(1).variance_vector(max_mix_index_1, :);
	var2  = hmm2.state(1).variance_vector(max_mix_index_2, :);
	mean1 = mean1(:);
	mean2 = mean2(:);
	var1  = var1(:);
	var2  = var2(:);

	if ( (numel(mean1) ~= numel(mean2)) || (numel(var1) ~= numel(var2)) || (numel(mean1) ~= numel(var1)) )
		error('divergence_distance_GMM: feature dimensions of the two models do not match.');
	end

	% Diagonal covariances: every matrix product/inverse in the formula
	% reduces to an element-wise operation on the variance vectors, so no
	% explicit diag()/inv() is needed.
	mean_diff  = mean1 - mean2;
	mean_part  = sum( (mean_diff .^ 2) .* (1 ./ var1 + 1 ./ var2) );
	trace_part = sum( var1 ./ var2 + var2 ./ var1 - 2 );

	distance_J_D = ( 1 / 2 ) * mean_part + ( 1 / 2 ) * trace_part;