import sys
import locale as lo
import os
import re
import math
import argparse
import datetime, time
try:
    import xml.etree.cElementTree as ET
except ImportError:
    # xml.etree.cElementTree was removed in Python 3.9; ElementTree uses the C accelerator automatically.
    import xml.etree.ElementTree as ET

import pdb


def is_valid_par_file(file_argument):
    """
    Validate the --INP argument.

    Terminates the program with an error message if the path does not exist or
    if its extension is not one of '.par', '.PAR', '.bpf' or '.BPF'.

    :param file_argument: path given via --INP
    :return: file_argument unchanged when it is acceptable
    """
    accepted_extensions = ('.par', '.PAR', '.bpf', '.BPF')
    extension = os.path.splitext(file_argument)[1]

    if not os.path.exists(file_argument):
        sys.exit("ERROR: mausbpf2eaf : The file %s does not exist! Please specify a valid BAS '.par' file as input. Exiting..." % file_argument)
    if extension not in accepted_extensions:
        sys.exit("ERROR: mausbpf2eaf : The extension %s is not a valid extension! Please specify a valid BAS '.par' file as input. Exiting..." % extension)
    return file_argument

def is_valid_output_file(file_path):
    """
    Validate the --OUT argument.

    Terminates the program with an error message unless the path ends in
    '.eaf' (the comparison is case-sensitive).

    :param file_path: path given via --OUT
    :return: True when the extension is '.eaf'
    """
    extension = os.path.splitext(file_path)[1]
    if extension == '.eaf':
        return True
    sys.exit("ERROR: mausbpf2eaf : The extension '%s' is not a valid output extension. Please specify a '.eaf' output file. Exiting..." %extension)
    


def write_elan_data(data, out_file, orig_wav_path, verbosity_level):
    """
    Build an ELAN EAF 3.0 document from the parsed partitur data and write it to out_file.

    The document contains:

    * the HEADER: media descriptor for orig_wav_path and the lastUsedAnnotationId property;
    * the TIME_ORDER, built by make_time_order();
    * the TIERs, built by make_tiers();
    * the LINGUISTIC_TYPEs "base" and "parent";
    * a LOCALE.

    :param data: parsed partitur dictionary as returned by ParParser.read_file()
    :param out_file: path of the .eaf file to write (UTF-8, with XML declaration)
    :param orig_wav_path: audio file path; written as an absolute file:// MEDIA_URL and
                          as "./<orig_wav_path>" RELATIVE_MEDIA_URL
    :param verbosity_level: debug output level; 1, 2 and 3 print progressively more detail
    """
    if verbosity_level >= 1:
        print("DEBUG: mausbpf2eaf : Constructing ELAN XML Data")
    if verbosity_level >= 2:
        print("DEBUG: mausbpf2eaf : Constructing .eaf header information")
    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Defining XML Namespace and creating annotation document")
    NS = "http://www.w3.org/2001/XMLSchema-instance"
    location_attribute = '{%s}noNamespaceSchemaLocation' % NS
    annot_doc = ET.Element("ANNOTATION_DOCUMENT", attrib={location_attribute: "http://www.mpi.nl/tools/elan/EAFv3.0.xsd"})

    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Calculating ISO timezone offset and setting DATE information.")
    # Local wall-clock time together with its UTC offset, e.g. 2017-05-04T13:37:00+02:00.
    # Converting an aware UTC "now" keeps time and offset consistent. Attaching the local
    # offset to a naive utcnow() value would shift the timestamp by that offset.
    creation_date = datetime.datetime.now(datetime.timezone.utc).astimezone().replace(microsecond=0)
    annot_doc.set("DATE", creation_date.isoformat())

    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Setting AUTHOR as MAUS")
    annot_doc.set("AUTHOR", "MAUS")
    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Setting VERSION to 3.0")
    annot_doc.set("VERSION", "3.0")

    header = ET.SubElement(annot_doc, "HEADER")
    header.set("TIME_UNITS", "milliseconds")

    # Media descriptor
    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Defining reference audio file to be at '%s' and having Mime_type: 'audio/x-wav'" % os.path.abspath(orig_wav_path))
    media_desc = ET.SubElement(header, "MEDIA_DESCRIPTOR")
    media_desc.set("MEDIA_URL", "file://" + os.path.abspath(orig_wav_path))
    # The MIME type is fixed to WAV. Detecting other types (FLAC, MP3) is not implemented
    # because ELAN does not support them yet, see
    # https://tla.mpi.nl/wp-content/uploads/2016/12/Video_encoding_guidelines_ELAN.pdf
    media_desc.set("MIME_TYPE", "audio/x-wav")
    media_desc.set("RELATIVE_MEDIA_URL", "./" + orig_wav_path)

    # PROPERTY lastUsedAnnotationId: created now so it sits inside HEADER.
    # Its value is only known after make_tiers() has run.
    prop = ET.SubElement(header, "PROPERTY")

    make_time_order(data, annot_doc, verbosity_level)
    max_annot_idx = make_tiers(data, annot_doc, verbosity_level)

    prop.set("NAME", "lastUsedAnnotationId")
    prop.text = str(max_annot_idx)

    # Linguistic types: "base" for tiers with their own time slots, and "parent" for
    # tiers that are aligned to (included in) the time reference tier.
    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Defining LINGUISTIC_TYPEs")
    ling_type = ET.SubElement(annot_doc, "LINGUISTIC_TYPE")
    ling_type.set("LINGUISTIC_TYPE_ID", "base")
    ling_type.set("TIME_ALIGNABLE", "true")

    ling_type_par = ET.SubElement(annot_doc, "LINGUISTIC_TYPE")
    ling_type_par.set("LINGUISTIC_TYPE_ID", "parent")
    ling_type_par.set("CONSTRAINTS", "Included_In")
    ling_type_par.set("TIME_ALIGNABLE", "true")

    # Locale: always written as German. The system locale is deliberately not queried;
    # locale.getdefaultlocale() may return None (e.g. in the C locale) and is deprecated.
    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Setting LOCALE to COUNTRY_CODE 'DE' and LANGUAGE_CODE 'de'.")
    locale_elem = ET.SubElement(annot_doc, "LOCALE")
    locale_elem.set("COUNTRY_CODE", "DE")
    locale_elem.set("LANGUAGE_CODE", "de")

    # Pretty print the XML
    if verbosity_level >= 3:
        print("DEBUG: mausbpf2eaf : Formatting the XML to be human readable")
    indent(annot_doc)

    if verbosity_level >= 1:
        print("DEBUG: mausbpf2eaf : Writing XML to file %s" % out_file)
    tree = ET.ElementTree(annot_doc)
    tree.write(out_file, xml_declaration=True, encoding="utf-8", method="xml")
    

def indent(elem, level=0):
    """
    Pretty-print an ElementTree element in place, recursively.

    Whitespace-only (or missing) text and tail values are replaced by a newline
    plus two spaces per nesting level. Text or tails that contain non-whitespace
    characters are left untouched.

    :param elem: element to format (modified in place)
    :param level: nesting depth of elem; 0 for the document root
    """
    def is_blank(text):
        return not text or not text.strip()

    newline_indent = "\n" + "  " * level
    children = list(elem)

    if not children:
        if level and is_blank(elem.tail):
            elem.tail = newline_indent
        return

    if is_blank(elem.text):
        elem.text = newline_indent + "  "
    if is_blank(elem.tail):
        elem.tail = newline_indent
    for child in children:
        indent(child, level + 1)
    # The tail of the last child precedes the parent's closing tag, so it goes back
    # to the parent's indentation level.
    last_child = children[-1]
    if is_blank(last_child.tail):
        last_child.tail = newline_indent


def make_time_order(data, annot_doc, verbosity_level):
    """
    Construct the TIME_ORDER element from all non-empty class4 and class2 tracks.

    Every line of such a track gets two TIME_SLOTs, one for the start and one for
    the end of the segment. Their IDs are "<track name lower-case><n>", where n
    counts start and end slots (line k -> slots 2k and 2k+1). make_tiers() and
    get_time_info() rely on this numbering.

    Sample positions are converted to milliseconds with the SAM header value. The
    end slot lies at (begin + duration + 1) samples. Slots are emitted in ascending
    time order; the sort is stable, so equal times keep their track order.

    Exits the program if no class4 track has data, if SAM is missing or not a
    positive integer, or if a segment line cannot be parsed.

    :param data: parsed partitur dictionary as returned by ParParser.read_file()
    :param annot_doc: ANNOTATION_DOCUMENT element to which TIME_ORDER is appended
    :param verbosity_level: debug output level
    """
    class4_list = [track_name for track_name, values in data["class4"].items() if values]
    if not class4_list:
        sys.exit("ERROR: mausbpf2eaf : Could not find any time-alignable tracks in the input file. The .par input must at least contain one of the following tracks: %s. Exiting..." %", ".join(data["class4"]))
    elif class4_list == ['TRN']:
        sys.stderr.write("WARNING: mausbpf2eaf : The input file only contains a TRN track which is not enough to build hierarchy information. Only writing out TRN data...\n")

    if verbosity_level >= 2:
        class4_list_output = ", ".join(class4_list)
        print("DEBUG: mausbpf2eaf : Constructing Time reference Levels from found class4 tiers: %s" %class4_list_output)

    time_order = ET.SubElement(annot_doc, "TIME_ORDER")

    # Validate before use: a SAM of "0" used to pass the old check and then cause a division by zero.
    try:
        sample_rate = int(data["header"]["SAM"])
    except (TypeError, ValueError):
        sample_rate = 0
    if sample_rate <= 0:
        sys.exit("ERROR: mausbpf2eaf : Could not read sample rate from input file. Please make sure there is a valid value in SAM. Exiting...")

    all_times = []
    for track_class in ("class4", "class2"):
        for track_name, time_values in data[track_class].items():
            all_times.extend(_segment_time_slots(track_name, time_values, sample_rate))

    # Sort the entries by time instead of listing one whole track after the other.
    all_times.sort(key=lambda time_tuple: time_tuple[0])
    for time_value, time_id in all_times:
        time_slot = ET.SubElement(time_order, "TIME_SLOT")
        time_slot.set("TIME_SLOT_ID", str(time_id))
        time_slot.set("TIME_VALUE", str(time_value))


def _segment_time_slots(track_name, time_values, sample_rate):
    """Return (milliseconds, slot_id) start/end pairs for every 'begin duration ...' line of one track."""
    slot_prefix = track_name.lower()
    slots = []
    for line_number, value in enumerate(time_values):
        fields = value.split()
        try:
            begin = int(fields[0])
            duration = int(fields[1])
        except (IndexError, ValueError):
            sys.exit("ERROR: mausbpf2eaf : The input file seems to be malformed at value '%s' in track '%s'. Please reformat the input." % (value, track_name))
        start_time = math.floor(begin / sample_rate * 1000)
        end_time = math.floor((begin + duration + 1) / sample_rate * 1000)
        slots.append((start_time, slot_prefix + str(2 * line_number)))
        slots.append((end_time, slot_prefix + str(2 * line_number + 1)))
    return slots

            
def get_time_reference_tier(data):
    """
    Choose the class4 track that class1 and class1mult annotations are aligned to.

    The first non-empty track in the preference order MAU, SAP, WOR, PHO, MAS wins.
    TRN is never chosen.

    :param data: parsed partitur dictionary
    :return: the track name, or None if none of the preferred tracks has data
    """
    for candidate in ("MAU", "SAP", "WOR", "PHO", "MAS"):
        if data["class4"][candidate]:
            return candidate
    return None

def make_tiers(data, annot_doc, verbosity_level):
    """
    Construct one TIER per non-empty track, together with its annotations.

    * class4 and class2 tracks use their own time slots: line k of track X
      refers to the slots "x<2k>" and "x<2k+1>" created by make_time_order().
    * class1 and class1mult tracks become children (PARENT_REF) of the time
      reference tier chosen by get_time_reference_tier(). Their time slots are
      found by looking up the word index written at the start of each line
      (get_time_info() / get_time_info_mult()).
    * If there is no time reference tier, only the TRN tier is written.

    :param data: parsed partitur dictionary as returned by ParParser.read_file()
    :param annot_doc: ANNOTATION_DOCUMENT element to which the TIER elements are appended
    :param verbosity_level: debug output level
    :return: index of the last annotation ID used (IDs are "a0", "a1", ...); -1 if none
    """
    time_reference_tier = get_time_reference_tier(data)
    annot_idx = 0

    if time_reference_tier is None:
        # Only TRN is available. make_time_order() exits when no class4 track has data,
        # so there is no tier to align the other tracks to.
        if verbosity_level >= 2:
            print("DEBUG: mausbpf2eaf : Creating annotations for Tier: '%s'" % "TRN")
        trn_lines = data["class4"]["TRN"]
        tier = ET.SubElement(annot_doc, "TIER")
        tier.set("ANNOTATOR", "MAUS")
        tier.set("TIER_ID", "TRN")
        if trn_lines:
            tier.set("LINGUISTIC_TYPE_REF", "base")
        for line_number, track_value in enumerate(trn_lines):
            if verbosity_level >= 3:
                print("DEBUG: mausbpf2eaf : Creating annotation: '%s'" % track_value)
            fields = track_value.replace(" ", "\t").split('\t')
            _add_alignable_annotation(tier, annot_idx,
                                      "trn" + str(2 * line_number),
                                      "trn" + str(2 * line_number + 1),
                                      " ".join(fields[3:]))
            annot_idx += 1
        return annot_idx - 1

    for track_class_name, track_data in data.items():
        if track_class_name == "header":
            continue
        for track_name, track_items in track_data.items():
            if not track_items:
                continue
            if verbosity_level >= 2:
                print("DEBUG: mausbpf2eaf : Creating annotations for Tier: '%s'" % track_name)
            tier = ET.SubElement(annot_doc, "TIER")
            tier.set("ANNOTATOR", "MAUS")
            tier.set("TIER_ID", track_name)
            if track_class_name in ("class4", "class2"):
                tier.set("LINGUISTIC_TYPE_REF", "base")
            elif track_class_name in ("class1", "class1mult"):
                tier.set("PARENT_REF", time_reference_tier)
                tier.set("LINGUISTIC_TYPE_REF", "parent")

            for line_number, track_value in enumerate(track_items):
                if verbosity_level >= 3:
                    print("DEBUG: mausbpf2eaf : Creating annotation: '%s'" % track_value)
                fields = track_value.replace(" ", "\t").split('\t')

                if track_class_name in ("class4", "class2"):
                    # Everything from the 4th field on is the label.
                    # NOTE(review): class2 lines are treated like class4 ("begin duration x label").
                    # If BPF class 2 lines are "begin duration label", this should be fields[2:].
                    time_slot_ref1 = track_name.lower() + str(2 * line_number)
                    time_slot_ref2 = track_name.lower() + str(2 * line_number + 1)
                    value_text = " ".join(fields[3:])
                elif track_class_name == "class1":
                    # The line starts with the word index it belongs to, followed by the value.
                    word_match = re.match(r'(\d+)\s', track_value)
                    if word_match is None:
                        sys.exit("ERROR: mausbpf2eaf : The input file seems to be malformed at value '%s' in track '%s'. Please reformat the input." % (track_value, track_name))
                    start, end = get_time_info(word_match.group(1), time_reference_tier, data)
                    time_slot_ref1 = time_reference_tier.lower() + str(start)
                    time_slot_ref2 = time_reference_tier.lower() + str(end)
                    value_text = track_value[word_match.end():]
                elif track_class_name == "class1mult":
                    # The first field holds one or more word references (e.g. "3,4" or "3;4").
                    start, end = get_time_info_mult(fields[0], time_reference_tier, data)
                    time_slot_ref1 = time_reference_tier.lower() + str(start)
                    time_slot_ref2 = time_reference_tier.lower() + str(end)
                    value_text = " ".join(fields[1:])
                else:
                    # Unknown track class: there is no way to align its lines.
                    continue

                _add_alignable_annotation(tier, annot_idx, time_slot_ref1, time_slot_ref2, value_text)
                annot_idx += 1

    return annot_idx - 1


def _add_alignable_annotation(tier, annot_idx, time_slot_ref1, time_slot_ref2, value_text):
    """Append ANNOTATION/ALIGNABLE_ANNOTATION "a<annot_idx>" spanning the two time slots, with value_text, to tier."""
    annotation = ET.SubElement(tier, "ANNOTATION")
    align_annotation = ET.SubElement(annotation, "ALIGNABLE_ANNOTATION")
    align_annotation.set("ANNOTATION_ID", "a" + str(annot_idx))
    align_annotation.set("TIME_SLOT_REF1", time_slot_ref1)
    align_annotation.set("TIME_SLOT_REF2", time_slot_ref2)
    annotation_val = ET.SubElement(align_annotation, "ANNOTATION_VALUE")
    annotation_val.text = value_text

def get_time_info_mult(time_reference, time_reference_tier, data):
    """
    Resolve a class1mult word reference to a (start_slot, end_slot) pair.

    * "a;b": the segment lies between two words. It starts at the end slot of
      word a and ends at the start slot of word b.
    * "a,b,...,z": the segment covers a list of words. It starts at the start
      slot of the first word and ends at the end slot of the last word.
    * "a": a single word; its own start and end slots are used.

    :param time_reference: the reference field of a class1mult line
    :param time_reference_tier: name of the class4 track used for alignment
    :param data: parsed partitur dictionary
    :return: tuple (start_slot_index, end_slot_index)
    """
    if ';' in time_reference:
        words = time_reference.split(';')
        _, start = get_time_info(words[0], time_reference_tier, data)
        end, _ = get_time_info(words[1], time_reference_tier, data)
        return start, end

    if ',' in time_reference:
        words = time_reference.split(',')
        start, _ = get_time_info(words[0], time_reference_tier, data)
        _, end = get_time_info(words[-1], time_reference_tier, data)
        return start, end

    start, end = get_time_info(time_reference, time_reference_tier, data)
    return start, end

def get_time_info(word_index, time_reference_tier, data):
    """
    Find the time slots spanned by one word in the time reference tier.

    Every class4 line k owns the slots 2k (start) and 2k+1 (end). The index is
    always doubled because ELAN does not allow the same time order element to be
    used by different annotations without a hierarchy reference. make_time_order()
    therefore creates a separate start and end slot for every line, even though
    that duplicates entries.

    :param word_index: word number (string or int). It is compared numerically with
                       the third field of each line of the time reference tier.
    :param time_reference_tier: name of the class4 track to search, e.g. "MAU"
    :param data: parsed partitur dictionary
    :return: (first_slot, last_slot): the start slot of the first line and the end slot
             of the last line linked to the word. Returns (0, 0), with a warning on
             stderr, if the word does not occur.
    """
    # Parse all link fields first, so malformed lines fail the same way as before.
    word_links = [int(time_stamp.replace(" ", "\t").split('\t')[2])
                  for time_stamp in data["class4"][time_reference_tier]]
    target = int(word_index)

    # enumerate() gives the true line position, even for duplicate lines
    # (list.index() would return the first duplicate).
    matching_positions = [position for position, link in enumerate(word_links) if link == target]
    if not matching_positions:
        sys.stderr.write("WARNING: mausbpf2eaf : The input file contains a class1 reference: %d ,that does not occur in time reference track: '%s'.\n" % (target, time_reference_tier))
        return 0, 0

    # A position of 0 is a valid match. The old "first_index == 0" sentinel let a later
    # line overwrite the start of a word that begins on the first line.
    return matching_positions[0] * 2, matching_positions[-1] * 2 + 1

    

class ParParser:
    """
    Parse a BAS Partitur (BPF) file and collect all supported tracks.

    The class attribute ``data`` is the template listing every supported key,
    grouped by BPF class. Each instance fills its own copy of it:

    * header entries store the first token after the key;
    * track entries store one string per line (the text after "KEY:").

    Only header information currently in use is SAM, as the sample rate is needed
    to convert sample positions to milliseconds.
    """
    data = {"header": {"SAM": 0, "MAO": None, "SAO": None, "REP": None, "SPN": None},
            "class1": {"ORT": [], "KAN": [], "KSS": [], "KAS": [], "SPK": [], "MRP": []},
            "class1mult": {"NOI": [], "TRL": [], "TR2": [], "TRO": []},
            "class2": {"SPD": []},
            "class4": {"MAU": [], "MAS": [], "PHO": [], "WOR": [], "SAP": [], "TRN": []}}

    def __init__(self, filepath):
        """
        :param filepath: path of the partitur file; stored as an absolute path
        """
        self.source_file = os.path.abspath(filepath)
        # Fresh copy of the template. The lists in the class attribute are shared by all
        # instances, so filling them directly would mix the tracks of different files.
        self.data = {
            track_class: {track_name: (list(value) if isinstance(value, list) else value)
                          for track_name, value in tracks.items()}
            for track_class, tracks in type(self).data.items()
        }

    def read_file(self, verbosity_level):
        """
        Read the partitur file into self.data.

        A line "KEY: rest" is stored when KEY (the exact text before the first colon)
        is a supported key. All other lines (e.g. LHD, LBD) are ignored.

        :param verbosity_level: debug output level
        :return: the filled data dictionary
        """
        if verbosity_level >= 1:
            print("DEBUG: mausbpf2eaf : Parsing partitur file '%s'" % self.source_file)

        # NOTE(review): BPF files are assumed to be UTF-8, matching the UTF-8 .eaf output.
        with open(self.source_file, 'r', encoding='utf-8') as source_data:
            for line in source_data:
                key, separator, rest = line.partition(':')
                if not separator:
                    continue
                for track_class, tracks in self.data.items():
                    if key not in tracks:
                        continue
                    if track_class == "header":
                        tokens = rest.split()
                        tracks[key] = tokens[0] if tokens else ""
                    else:
                        tracks[key].append(rest.strip())

        if verbosity_level >= 2:
            found_tracks = [track for tracks in self.data.values()
                            for track, value in tracks.items() if value]
            print("DEBUG: mausbpf2eaf : Found Tracks: '%s'" % ", ".join(found_tracks))
        # No class1 track is required for class1mult tracks: make_tiers() aligns them to the
        # class4 time reference tier.
        return self.data


if __name__ == "__main__":
    # Command line: --INP file.par [--OUT file.eaf] [-v | -vv | -vvv] [--version]
    cli = argparse.ArgumentParser(description="Parse Input Arguments.")
    cli.add_argument("--INP", metavar="FILE", help="Input BPF .par file to be converted")
    cli.add_argument("--OUT", metavar="FILE", help="Output path where the final .eaf file will be stored")
    cli.add_argument("--verbose", "-v", default=0, help="Verbosity level", action="count")
    cli.add_argument("--version", help="Outputs version number", action="store_true")
    args = cli.parse_args()

    version_number = "1.0.0"
    if args.version:
        print(version_number)
    else:
        if not args.INP:
            sys.exit("ERROR: mausbpf2eaf : Missing required input argument --INP. Please specify a valid BAS partitur file to be converted.")
        is_valid_par_file(args.INP)

        # The audio file is assumed to have the input's base name with a .wav extension,
        # and the default output uses the same base name with .eaf.
        input_stem = os.path.splitext(os.path.basename(args.INP))[0]
        orig_wav_path = input_stem + ".wav"
        out_file = os.path.splitext(orig_wav_path)[0] + ".eaf"

        if not args.OUT:
            sys.stderr.write("WARNING: mausbpf2eaf : No output file specified. Writing eaf to default location: '%s'\n" %os.path.abspath(out_file))
        elif is_valid_output_file(args.OUT):
            out_file = args.OUT

        partitur_data = ParParser(args.INP).read_file(args.verbose)
        write_elan_data(partitur_data, out_file, orig_wav_path, args.verbose)
                             



