#! /usr/bin/python3
import string
import re

# Tier names that main() filters out of the input BPF content before the
# output file is written.
TIERS_TO_DELETE = ["TRO", "TRN"]

class Item:
    """A piece of text together with the index it is identified by.

    The object behaves like its text for ``len()`` and indexing, so it can
    be compared character by character.
    """

    def __init__(self, index, text):
        self.text = text
        self.index = index

    def __str__(self):
        # Rendered as "<text>_<index>".
        return "_".join((self.text, str(self.index)))

    def __getitem__(self, index):
        return self.text[index]

    def __len__(self):
        return len(self.text)

class OrtItem(Item):
    """Item created from an ``ORT:`` entry; adds nothing to ``Item``."""

    def __init__(self, index, text):
        super().__init__(index, text)

class TransItem(Item):
    """Item that carries a separate text used for length and indexing.

    ``text`` is the stored label; ``aligntext`` is what ``len()`` and
    ``[]`` operate on. When no ``aligntext`` is given, ``text`` is used.
    """

    def __init__(self, index, text, aligntext=None):
        super().__init__(index, text)
        if aligntext is None:
            aligntext = self.text
        self.aligntext = aligntext

    def __getitem__(self, index):
        return self.aligntext[index]

    def __len__(self):
        return len(self.aligntext)


class TokenList:
    """Ordered container of tokens, stored in the ``tokens`` list."""

    def __init__(self):
        self.tokens = []

    def __getitem__(self, index):
        return self.tokens[index]

    def __len__(self):
        return len(self.tokens)

    def sublist(self, start, stop):
        """Return a new list of the same class holding ``tokens[start:stop]``.

        The token objects themselves are shared, not copied.
        """
        newlist = self.__class__()
        newlist.tokens.extend(self.tokens[start:stop])
        return newlist

    def build(self, source=None):
        """Check that the list is not empty.

        Args:
            source: optional name of the input the tokens came from; it is
                only used in the error.

        Raises:
            Exception: if ``tokens`` is empty.
        """
        # The previous version referenced an undefined name here, so an
        # empty list ended in a NameError instead of this error.
        if not self.tokens:
            if source is None:
                raise Exception("No items found")
            raise Exception("No items found in", source)

class OrtList(TokenList):
    """Token list filled from the ``ORT:`` lines of a BPF file."""

    # An optionally negative decimal integer.
    NUM_RGX = re.compile("-?[0-9]+")
    # Key that marks an ORT line.
    ORTKEY = "ORT:"

    def __init__(self):
        TokenList.__init__(self)

    def interpret_ort(self, line):
        """Parse one ORT line of the form ``ORT: <integer> <text>``.

        Args:
            line: the raw line, with or without surrounding whitespace.

        Returns:
            A tuple ``(index, text)`` with ``index`` converted to ``int``.

        Raises:
            Exception: if the line does not consist of exactly three
                whitespace-separated fields, the first field is not the
                ORT key, or the second field is not an integer.
        """
        fields = line.split()
        # fullmatch: the whole field has to be a number. A prefix match
        # would accept e.g. "12abc" and then fail inside int().
        if len(fields) == 3 \
                and fields[0] == self.ORTKEY \
                and self.NUM_RGX.fullmatch(fields[1]):
            return (int(fields[1]), fields[2])

        raise Exception("Invalid ORT entry:", line)

    def build(self, bpf):
        """Append an OrtItem for every line of ``bpf`` starting with the ORT key.

        Args:
            bpf: path of the file to read.

        Raises:
            Exception: if the file cannot be opened, or if an ORT line is
                malformed (see ``interpret_ort``).
        """
        try:
            handle = open(bpf, "r")
        except OSError as err:
            raise Exception("Could not open", bpf) from err

        # The with-block closes the file even when a malformed line raises.
        with handle:
            for line in handle:
                if line.startswith(self.ORTKEY):
                    self.tokens.append(OrtItem(*self.interpret_ort(line)))


class TransList(TokenList):
    """Token list filled by tokenizing a transcription text file."""

    # Whitespace characters and the visible two-character sequences that
    # replace them in token labels.
    _ESCAPES = str.maketrans({"\n": "\\n", "\t": "\\t", "\r": "\\r", " ": "\\s"})

    def __init__(self):
        TokenList.__init__(self)

    def escape(self, text):
        """Return ``text`` with newline, tab, carriage return and space
        replaced by the literal sequences ``\\n``, ``\\t``, ``\\r``, ``\\s``."""
        return text.translate(self._ESCAPES)

    def build(self, transcription, splitpunct):
        """Read ``transcription`` and split its content into TransItems.

        The text is split at whitespace runs. A whitespace run is glued to
        the end of the preceding token (a leading run, having no
        predecessor, becomes a token of its own). Chunks that start with a
        ``<...>`` tag are kept whole; all other chunks are additionally
        split around punctuation clusters containing at least one of the
        ``splitpunct`` characters. Each token gets its position in the
        resulting sequence as index, the escaped text as label, and the
        label without escaped whitespace as alignment text.

        Args:
            transcription: path of the text file to read.
            splitpunct: string or iterable of strings whose characters
                trigger a split. They are taken literally; if there are
                none, no punctuation splitting is done.

        Raises:
            Exception: if the file cannot be opened or yields no tokens.
        """
        try:
            handle = open(transcription, "r")
        except OSError as err:
            raise Exception("Could not open", transcription) from err

        with handle:
            text = handle.read()

        tag_rgx = re.compile("<.*>")
        ws_rgx = re.compile("([" + string.whitespace + "]+)")

        # NOTE: this class is built unescaped. The backslash contained in
        # string.punctuation escapes the "]" that follows it, so "]" is a
        # member of the class while the backslash itself is not.
        allpunct = "[" + string.punctuation + "”“’‘…]"

        # Escape the caller-supplied characters so that "]", "^", "-" or
        # "\\" are matched literally instead of altering the class.
        split_chars = re.escape("".join(splitpunct))
        if split_chars:
            punct_rgx = re.compile(
                "(" + allpunct + "*[" + split_chars + "]+" + allpunct + "*)")
        else:
            # An empty class would produce a pattern that matches the empty
            # string and chops every chunk into single characters.
            punct_rgx = None

        splittext = []
        for ws_token in ws_rgx.split(text):
            if tag_rgx.match(ws_token):
                splittext.append(ws_token)
            elif ws_rgx.match(ws_token) and splittext:
                splittext[-1] += ws_token
            elif ws_token:
                if punct_rgx is None:
                    splittext.append(ws_token)
                else:
                    splittext.extend(
                        piece for piece in punct_rgx.split(ws_token) if piece)

        labels = [self.escape(piece) for piece in splittext]
        self.tokens = [
            TransItem(index, label, re.sub("\\\\[strn]", "", label))
            for index, label in enumerate(labels)
        ]

        # Checked here so the error names the transcription file regardless
        # of how the base class reports an empty list.
        if not self.tokens:
            raise Exception("No transcription items found in", transcription)
        TokenList.build(self)


def subscost(text1, text2):
    """Return the cost of substituting ``text1`` by ``text2``.

    Two single-character strings cost 0 when equal and 1 otherwise. If
    either argument is empty the cost is 1. In every other case a Matrix
    is built over the two sequences and the value of its last cell is
    divided by the length of the longer sequence.
    """
    both_single_chars = (
        isinstance(text1, str)
        and isinstance(text2, str)
        and len(text1) == 1
        and len(text2) == 1
    )
    if both_single_chars:
        return int(text1 != text2)

    size1 = len(text1)
    size2 = len(text2)
    if size1 == 0 or size2 == 0:
        return 1

    grid = Matrix(text1, text2)
    grid.build()

    # Bottom-right cell, normalised by the longer of the two inputs.
    total = grid.matrix[size1 - 1][size2 - 1]
    longest = max(size1, size2)
    return float(total) / float(longest)

class PathItem:
    """One step of an alignment path: an ORT item paired with a
    transcription item. Either side may be ``None``."""

    def __init__(self, OrtItem, TransItem):
        self.trans_item = TransItem
        self.ort_item = OrtItem

    def __str__(self):
        # "<ort>+<trans>"; a missing side is rendered as "None".
        return str(self.ort_item) + "+" + str(self.trans_item)

class Path:
    """Sequence of PathItem objects, stored in ``path_items``."""

    def __init__(self):
        self.path_items = []

    def __str__(self):
        return "\n".join(str(item) for item in self.path_items)

    def append(self, path_item):
        """Insert ``path_item`` at the FRONT of the path.

        Matrix.find_path walks from the end point back to the origin, so
        prepending leaves the finished path in forward order.
        """
        self.path_items.insert(0, path_item)

    def extend(self, path):
        """Add all items of another Path to the end of this one."""
        self.path_items.extend(path.path_items)

    def half(self):
        """Return a new Path with the first ``len // 2 + 1`` items."""
        new_path = Path()
        new_path.path_items.extend(
            self.path_items[:len(self.path_items) // 2 + 1])
        return new_path

    @staticmethod
    def _keep_best(competitors, which):
        """Keep the ``which`` side only on the cheapest competitor.

        A competitor with a missing side costs 1; otherwise its cost is
        the substitution cost of its two sides. The first competitor with
        the lowest cost keeps its item, all others get that side set to
        ``None``.
        """
        if len(competitors) < 2:
            # Nothing to decide; skip the cost computation.
            return

        costs = [
            1 if (comp.ort_item is None or comp.trans_item is None)
            else subscost(comp.ort_item, comp.trans_item)
            for comp in competitors
        ]
        winner = costs.index(min(costs))
        for position, comp in enumerate(competitors):
            if position == winner:
                continue
            if which == "ort":
                comp.ort_item = None
            else:
                comp.trans_item = None

    def resolve(self, competitors, which):
        """Settle a list of competing path items, then empty the list.

        If the competitors refer to more than one distinct index on the
        ``which`` side ("ort", anything else means the transcription
        side), only the cheapest one keeps that side.

        Args:
            competitors: list of PathItem objects; it is cleared in place.
            which: "ort" to work on ``ort_item``, otherwise ``trans_item``.
        """
        # A set is required here; the former dict literal had no add().
        indices = set()
        for comp in competitors:
            item = comp.ort_item if which == "ort" else comp.trans_item
            if item is not None:
                indices.add(item.index)

        if len(indices) > 1:
            self._keep_best(competitors, which)

        competitors.clear()

    def clean(self):
        """Make every ORT index and every transcription index occur once.

        First for the ORT side, then for the transcription side, path
        items are grouped by the index of that side; within each group
        only the cheapest item keeps the side, the others have it set to
        ``None``.
        """
        for which in ("ort", "trans"):
            groups = {}
            for path_item in self.path_items:
                item = path_item.ort_item if which == "ort" else path_item.trans_item
                if item is not None:
                    groups.setdefault(item.index, []).append(path_item)

            for competitors in groups.values():
                self._keep_best(competitors, which)



class Matrix:
    """Cost matrix for aligning two sequences.

    ``matrix[i][j]`` holds the accumulated cost of aligning
    ``outer_text[0..i]`` with ``inner_text[0..j]``: stepping along one
    sequence only costs 1, stepping along both costs ``subscost`` of the
    two elements.
    """

    def __init__(self, outer_text, inner_text):
        self.matrix = [[None] * len(inner_text) for _ in range(len(outer_text))]
        self.outer_text = outer_text
        self.inner_text = inner_text

    def build(self):
        """Fill in all cells. Does nothing if either sequence is empty."""
        if len(self.inner_text) == 0 or len(self.outer_text) == 0:
            return

        rows = len(self.matrix)
        cols = len(self.matrix[0])

        self.matrix[0][0] = subscost(self.outer_text[0], self.inner_text[0])

        # First column: match inner_text[0] here, or extend the cell above.
        for idx in range(1, rows):
            subs = subscost(self.outer_text[idx], self.inner_text[0])
            self.matrix[idx][0] = min(subs + idx, self.matrix[idx - 1][0] + 1)

        # First row: match outer_text[0] here, or extend the previous cell.
        for jdx in range(1, cols):
            subs = subscost(self.outer_text[0], self.inner_text[jdx])
            self.matrix[0][jdx] = min(subs + jdx, self.matrix[0][jdx - 1] + 1)

        for idx in range(1, rows):
            for jdx in range(1, cols):
                subs = subscost(self.outer_text[idx], self.inner_text[jdx])
                self.matrix[idx][jdx] = min(
                    self.matrix[idx - 1][jdx - 1] + subs,
                    self.matrix[idx - 1][jdx] + 1,
                    self.matrix[idx][jdx - 1] + 1,
                )

    def find_endpoint(self, fixed_outer, fixed_inner):
        """Return the cell ``(idx, jdx)`` where backtracking starts.

        A fixed side ends at its last element. A free outer side ends at
        the row with the lowest cost in the last column; a free inner side
        ends at the column with the lowest cost in the last row. Ties go
        to the lowest index.
        """
        if fixed_outer:
            idx = len(self.matrix) - 1
        else:
            last_column = [row[-1] for row in self.matrix]
            idx = last_column.index(min(last_column))

        if fixed_inner:
            jdx = len(self.matrix[0]) - 1
        else:
            last_row = self.matrix[-1]
            jdx = last_row.index(min(last_row))

        return (idx, jdx)

    def find_path(self, fixed_outer, fixed_inner):
        """Return the alignment as a Path in forward order.

        If one of the sequences is empty, every element of the other one
        is paired with ``None``. Otherwise the path is traced back from
        the end point to cell (0, 0), at each step moving to the cheapest
        neighbouring predecessor (diagonal preferred on ties, then the
        outer step).

        Args:
            fixed_outer: whether the path must end at the last outer element.
            fixed_inner: whether the path must end at the last inner element.
        """
        path = Path()

        if len(self.outer_text) == 0 or len(self.inner_text) == 0:
            # Path.append() prepends, so the elements are added last to
            # first; adding them in ascending order produced a reversed
            # path.
            for jdx in reversed(range(len(self.inner_text))):
                path.append(PathItem(None, self.inner_text[jdx]))
            for idx in reversed(range(len(self.outer_text))):
                path.append(PathItem(self.outer_text[idx], None))
            return path

        idx, jdx = self.find_endpoint(fixed_outer, fixed_inner)

        while True:
            path.append(PathItem(self.outer_text[idx], self.inner_text[jdx]))
            if idx == 0 and jdx == 0:
                break
            elif idx == 0:
                jdx -= 1
            elif jdx == 0:
                idx -= 1
            else:
                diag = self.matrix[idx - 1][jdx - 1]
                outer_step = self.matrix[idx - 1][jdx]
                inner_step = self.matrix[idx][jdx - 1]

                if diag <= outer_step and diag <= inner_step:
                    idx -= 1
                    jdx -= 1
                elif outer_step <= inner_step:
                    idx -= 1
                else:
                    jdx -= 1

        return path


class Aligner:
    """Aligns the ORT items of a BPF file with the tokens of a transcription.

    The two token lists are aligned window by window; the accumulated
    result is kept in ``path``.
    """

    def __init__(self, verbose, windowsize):
        self.verbose = verbose
        self.windowsize = windowsize
        self.ort_list = OrtList()
        self.trans_list = TransList()
        self.path = Path()

    def build(self, bpffile, textfile, splitpunct):
        """Load the ORT items from ``bpffile`` and the tokens from ``textfile``."""
        self.ort_list.build(bpffile)
        self.trans_list.build(textfile, splitpunct)

    def _resume_index(self, side):
        """Return ``index + 1`` of the last path item whose ``side``
        attribute ("ort_item" or "trans_item") is not ``None``.

        Scans the accumulated path from its end; runs off the list with
        an IndexError if no such item exists.
        """
        offset = -1
        while True:
            item = getattr(self.path.path_items[offset], side)
            if item is not None:
                return item.index + 1
            offset -= 1

    def align(self):
        """Compute the alignment path over both complete lists.

        Each round aligns a window of at most ``windowsize`` items per
        side. Unless the window reaches the end of both lists, only the
        first half of its path is kept and the next window starts right
        after the last kept item on each side. Finally duplicates are
        removed with ``Path.clean``.

        NOTE(review): the restart position comes from the items' ``index``
        attribute, so this presumably relies on the ORT indices in the BPF
        being consecutive and starting at 0 - confirm against the input.
        """
        ort_total = len(self.ort_list)
        trans_total = len(self.trans_list)
        idx_left = jdx_left = 0

        while True:
            if self.verbose:
                print(idx_left, "/", jdx_left, " (", ort_total, "/", trans_total, ")\r", sep="", end="")

            idx_right = min(idx_left + self.windowsize, ort_total)
            jdx_right = min(jdx_left + self.windowsize, trans_total)
            ort_done = idx_right == ort_total
            trans_done = jdx_right == trans_total

            ort_window = self.ort_list.sublist(idx_left, idx_right)
            trans_window = self.trans_list.sublist(jdx_left, jdx_right)
            matrix = Matrix(ort_window, trans_window)
            matrix.build()

            path = matrix.find_path(ort_done, trans_done)

            if ort_done and trans_done:
                # Last window: take its path completely.
                self.path.extend(path)
                break

            self.path.extend(path.half())

            idx_left = self._resume_index("ort_item")
            jdx_left = self._resume_index("trans_item")

            assert idx_left is not None and jdx_left is not None

        self.path.clean()

        if self.verbose:
            print(idx_right, "/", jdx_right, " (", ort_total, "/", trans_total, ")\r", sep="")

    def write_tro_tier(self, handle):
        """Write one ``TRO: <ort index> <label>`` line per transcription item.

        Path items without a transcription item are skipped; a
        transcription item without an ORT partner gets the index -1.

        Args:
            handle: writable text file object.
        """
        for step in self.path.path_items:
            trans = step.trans_item
            if trans is None:
                continue

            label = trans.text
            index = -1 if step.ort_item is None else step.ort_item.index

            handle.write("TRO: " + str(index) + " " + label + "\n")

def main(bpffile, transcriptionfile, outfile, verbose, windowsize, splitpunct):
    """Align a BPF file with a transcription and write the result.

    The output file contains the lines of ``bpffile`` except those of the
    tiers listed in TIERS_TO_DELETE, followed by the newly computed TRO
    tier.

    Args:
        bpffile: path of the input BPF file.
        transcriptionfile: path of the transcription text file.
        outfile: path of the file to write.
        verbose: whether the aligner prints its progress.
        windowsize: number of items per side in one alignment window.
        splitpunct: punctuation characters at which transcription tokens
            are split.
    """
    aligner = Aligner(verbose = verbose, windowsize = windowsize)
    aligner.build(bpffile, transcriptionfile, splitpunct = splitpunct)
    aligner.align()

    with open(bpffile, "r") as handle:
        bpfstring = handle.read()
    if not bpfstring.endswith("\n"):
        bpfstring += "\n"

    # Drop the lines of the tiers that must not be copied to the output.
    # A line belongs to a tier when it starts with "<TIER>:". A plain
    # substring test would also throw away lines of other tiers that merely
    # contain e.g. "TRO" somewhere in their content.
    tier_keys = tuple(tier + ":" for tier in TIERS_TO_DELETE)
    kept_lines = [
        line + "\n"
        for line in bpfstring.splitlines()
        if not line.startswith(tier_keys)
    ]

    with open(outfile, "w") as whandle:
        whandle.write("".join(kept_lines))
        aligner.write_tro_tier(whandle)
