#! /usr/bin/python3

import re
import sys
import datetime


def get_samplerate(bpf):
    """Return the sample rate (as a float) from the first ``SAM`` line of a BPF file.

    Args:
        bpf: path of the BPF file to scan.

    Raises:
        Exception: if the first line starting with ``SAM`` has no parsable,
            positive number in its second whitespace-separated field, or if
            the file contains no ``SAM`` line at all.
    """
    # `with` guarantees the file is closed when we return or raise mid-loop.
    with open(bpf) as handle:
        for line in handle:
            if line.startswith("SAM"):
                try:
                    rate = float(line.strip().split()[1])
                except (IndexError, ValueError) as err:
                    raise Exception("Not a valid sample rate entry: ", line) from err
                # Callers divide sample positions by this value; reject zero,
                # negative and NaN rates here instead of failing later.
                if not rate > 0:
                    raise Exception("Not a valid sample rate entry: ", line)
                return rate

    raise Exception("No valid sample rate tag in ", bpf)

def read_metadata(bpf):
    """Return the BPF header block, i.e. the file text up to and including ``LBD:``.

    The whole file is read. The returned string is the complete regex match:
    everything from the start of the file through the last ``"\\nLBD:\\n"``
    occurrence, including that ``LBD:`` line and its newline.

    Args:
        bpf: path of the BPF file.

    Raises:
        Exception: if no ``LBD:`` line (preceded by a newline) is present.
    """
    with open(bpf, "r") as handle:
        data = handle.read()

    # DOTALL lets `.` cross newlines, so `(.*)` spans all header lines.
    match = re.search("^(.*)\nLBD:\n", data, re.DOTALL)
    if match is None:
        raise Exception("No header terminated by an 'LBD:' line found in " + str(bpf))
    return match.group()

def read_mautier(bpf):
    """Collect, per word link index, the overall span of its MAU segments.

    Every line starting with ``MAU`` is split on whitespace. Field 1 is read
    as the segment start, field 2 as its duration, and field 3 as the word
    link index. The segment end is computed as start + duration.

    Returns:
        dict mapping link index -> {"start": min start, "end": max end} over
        all MAU segments carrying that index (including negative indices).
    """
    spans = {}
    with open(bpf) as fh:
        for raw in fh:
            if not raw.startswith("MAU"):
                continue
            fields = raw.split()
            link = int(fields[3])
            seg_begin = int(fields[1])
            seg_end = seg_begin + int(fields[2])

            span = spans.get(link)
            if span is None:
                spans[link] = {"start": seg_begin, "end": seg_end}
            else:
                # widen the existing span to cover this segment as well
                span["start"] = min(span["start"], seg_begin)
                span["end"] = max(span["end"], seg_end)

    return spans

def pad(num, l):
    """Left-pad ``str(num)`` with zeros to a width of at least *l* characters.

    Strings already *l* characters or longer are returned unchanged. The
    padding is purely textual, so a negative number is padded in front of its
    sign (``pad(-5, 3) == "0-5"``).
    """
    return str(num).rjust(l, "0")
    
def formsrt(milliseconds):
    """Render a time in milliseconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The input is truncated with ``int()`` first. Hours are not capped, so
    values of 100 hours or more produce a wider hour field.
    """
    total_ms = int(milliseconds)
    total_s, frac_ms = divmod(total_ms, 1000)
    total_min, secs = divmod(total_s, 60)
    hours, mins = divmod(total_min, 60)
    return "{:02d}:{:02d}:{:02d},{:03d}".format(hours, mins, secs, frac_ms)

def formvtt(milliseconds):
    """Render a time in milliseconds as a WebVTT timestamp ``HH:MM:SS.mmm``.

    The input is truncated with ``int()`` first. Hours are always emitted,
    and are not capped at two digits.
    """
    total_ms = int(milliseconds)
    total_s, frac_ms = divmod(total_ms, 1000)
    total_min, secs = divmod(total_s, 60)
    hours, mins = divmod(total_min, 60)
    return "{:02d}:{:02d}:{:02d}.{:03d}".format(hours, mins, secs, frac_ms)
    
def formsub(centiseconds):
    """Render a time in centiseconds as a SubViewer timestamp ``HH:MM:SS.cc``.

    SubViewer time-code lines have the form ``HH:MM:SS.cc,HH:MM:SS.cc``.
    The comma separates start from end, so the fractional separator must be
    a period. The previous "," produced ambiguous lines such as
    ``00:00:41,00,00:00:44,40``.

    The input is truncated with ``int()`` first.
    """
    centiseconds = int(centiseconds)
    s, cs = divmod(centiseconds, 100)
    m, s = divmod(s, 60)
    h, m = divmod(m, 60)

    return "{:02d}:{:02d}:{:02d}.{:02d}".format(h, m, s, cs)


        
def output(out, typ, subtitles, samplerate, bpf, origbpf):
    """Write *subtitles* to the file *out* in the format named by *typ*.

    Supported values of *typ*:
      * ``"srt"`` / ``"vtt"``: positions are converted to milliseconds via
        *samplerate*. The vtt file starts with a NOTE holding the header of
        *origbpf*, with its newlines replaced by ``"; "``.
      * ``"sub"``: positions are converted to centiseconds.
      * ``"bpf"``: *bpf* is copied verbatim.
      * ``"bpf+trn"``: *bpf* is copied without its ``TRN:`` lines, then new
        ``TRN:`` lines are appended.
      * ``"trn"``: only the new ``TRN:`` lines are written.
    Any other value leaves *out* created but empty.

    Each subtitle is a dict with ``"start"``/``"end"`` positions, a
    ``"label"``, and (for TRN output) ``"indices"``. Negative indices are
    left out of TRN lines.
    """
    def to_unit(samples, units_per_second):
        # same evaluation order as a plain `samples / samplerate * units`
        return samples / samplerate * units_per_second

    with open(out, "w") as whandle:
        if typ == "srt":
            for number, subtitle in enumerate(subtitles, start=1):
                whandle.write("{}\n".format(number))
                whandle.write(formsrt(to_unit(subtitle["start"], 1000)) + " --> ")
                whandle.write(formsrt(to_unit(subtitle["end"], 1000)) + "\n")
                whandle.write(subtitle["label"].strip() + "\n\n")

        elif typ == "sub":
            whandle.write("[INFORMATION]\n[SUBTITLE]\n")
            for subtitle in subtitles:
                whandle.write(formsub(to_unit(subtitle["start"], 100)) + ",")
                whandle.write(formsub(to_unit(subtitle["end"], 100)) + "\n")
                whandle.write(subtitle["label"].strip() + "\n\n")

        elif typ == "vtt":
            header_note = read_metadata(origbpf).strip().replace("\n", "; ")
            whandle.write("WEBVTT\n\n")
            whandle.write("NOTE " + header_note + "\n\n")
            for subtitle in subtitles:
                whandle.write(formvtt(to_unit(subtitle["start"], 1000)) + " --> ")
                whandle.write(formvtt(to_unit(subtitle["end"], 1000)) + "\n")
                whandle.write(subtitle["label"].strip() + "\n\n")

        elif typ.startswith("bpf"):
            keep_trn = typ == "bpf"
            with open(bpf, "r") as rhandle:
                for line in rhandle:
                    if keep_trn or not line.startswith("TRN:"):
                        whandle.write(line)

        # "bpf+trn" falls through from the copy branch above; "trn" lands here directly.
        if typ in ("bpf+trn", "trn"):
            for subtitle in subtitles:
                # NOTE(review): duration is written as end - start - 1, while
                # read_mautier computes end as start + MAU duration; confirm the
                # extra -1 matches the intended TRN duration convention.
                fields = (
                    "TRN:",
                    str(int(subtitle["start"])),
                    str(int(subtitle["end"] - subtitle["start"] - 1)),
                    ",".join(str(idx) for idx in subtitle["indices"] if idx >= 0),
                    subtitle["label"],
                )
                whandle.write(" ".join(fields) + "\n")




def unescape(label, unescape_tab_and_newline=False):
    """Turn BPF-style escape sequences in *label* back into whitespace.

    The two-character sequence backslash + ``s`` always becomes a space.
    When *unescape_tab_and_newline* is true, backslash + ``r``/``t``/``n``
    are decoded too. Every run of tab/newline/CR characters, including ones
    already present in *label*, is then collapsed into a single space.
    """
    text = label.replace("\\s", " ")
    if not unescape_tab_and_newline:
        return text

    # decode in the fixed order \r, \t, \n
    for escaped, literal in (("\\r", "\r"), ("\\t", "\t"), ("\\n", "\n")):
        text = text.replace(escaped, literal)
    return re.sub("[\t\n\r]+", " ", text)
                        

def warn(warning):
    """Emit ``WARNING: <warning>`` as one line on stderr and flush it immediately."""
    stream = sys.stderr
    stream.write("WARNING: " + str(warning) + "\n")
    stream.flush()


def make_subtitle(indices, totallabel, maudict, marker, tag_markers, outformat, tier):
    """Build one subtitle from the tokens collected since the previous break.

    Args:
        indices: iterable of int word link indices for the collected tokens;
            ``-1`` marks an unlinked token.
        totallabel: list of token labels, in file order.
        maudict: mapping link index -> {"start": ..., "end": ...}, as
            produced by ``read_mautier``.
        marker: separation marker kind. When it is ``"tag"``, every string
            in *tag_markers* and its surrounding whitespace is replaced by
            a single space.
        tag_markers: literal tag strings (e.g. ``"<BREAK>"``) to remove.
        outformat: output format. For non-ORT tiers, escaped tab/newline/CR
            are decoded only for ``"srt"``, ``"sub"`` and ``"vtt"``.
        tier: tier name. ORT tokens are joined with spaces; tokens of other
            tiers are concatenated and then unescaped.

    Returns:
        ``{"start", "end", "label", "indices"}``, where start/end are the
        min start / max end over all non-negative indices found in
        *maudict* and indices is the sorted list. If all indices are -1 or
        none of them has a MAU span, a warning is printed and
        ``{"label": label, "indices": None}`` is returned instead.
    """
    if tier == "ORT":
        label = " ".join(totallabel)
    else:
        label = unescape("".join(totallabel),
                         unescape_tab_and_newline=outformat in ("srt", "sub", "vtt"))

    if marker == "tag":
        for tag_marker in tag_markers:
            # markers are literal text, not regex patterns
            label = re.sub(r"\s*" + re.escape(tag_marker) + r"\s*", " ", label)

    label = label.strip()

    if all(index == -1 for index in indices):
        warn("All tokens in this subtitle are linked to '-1' (unlinked): " + label)
        return {"label": label, "indices": None}

    spans = [maudict[index] for index in indices if index >= 0 and index in maudict]

    if not spans:
        # indices are ints, so they must be stringified before joining
        warn(" ".join(("Could not find MAU segments for the following subtitle: ",
            label, "[", ",".join(str(index) for index in sorted(indices)), "]")))
        return {"label": label, "indices": None}

    return {"start": min(span["start"] for span in spans),
            "end": max(span["end"] for span in spans),
            "label": label,
            "indices": sorted(indices)}

def main(bpf, origbpf, outfile, marker, outformat, maxlength, tier):
    """Split the tokens of one BPF tier into subtitles and write them to *outfile*.

    Tokens are collected from lines starting with *tier*. A subtitle ends
    when a token contains one of the separation markers selected by
    *marker*, or when *maxlength* tokens have been collected. Its time span
    comes from the MAU segments linked to its tokens (see make_subtitle).

    Args:
        bpf: BPF file providing the sample rate, MAU tier and the *tier* tokens.
        origbpf: file whose header is used for the vtt NOTE; here it is only
            forwarded to output().
        outfile: path of the file to write.
        marker: one of the MARKERS keys ("tag", "newline", "punct"); any
            other value raises KeyError. Forced to "tag" when *tier* is "ORT".
        outformat: format name passed to output() (and to make_subtitle()).
        maxlength: maximum number of tokens per subtitle; values <= 0 mean
            no limit.
        tier: tier name, matched as a line prefix (e.g. "ORT").

    Raises:
        Exception: if *tier* is "ORT" and *outformat* is not "sub"/"srt"/"vtt",
            or if no line of *bpf* starts with *tier*.
    """
    samplerate = get_samplerate(bpf)

    if maxlength <= 0:
        maxlength = float('inf')

    if tier == "ORT":
        warning = "Tier is ORT. Expecting <BREAK> tag in ORT tier to delimit subtitles. Setting '--marker tag' and ignoring transcription."
        warn(warning)
        marker = "tag"
    
        if not outformat in ("sub", "srt", "vtt"):
            raise Exception("Tier is ORT. Outformat must be 'sub', 'srt' or 'vtt'.")
    
    # "newline" matches the two-character escape backslash + n inside token labels.
    MARKERS = {"tag": ["<BREAK>"], "newline": ["\\n"], "punct": [".", "!", ":", "?", "…"]}

    markers = MARKERS[marker]

    maudict = read_mautier(bpf)

    subtitles = []

    found_markers = 0
    over_long = 0

    # Matches tokens with a punctuation char between '<' and '>' (e.g. "<p:>");
    # with --marker punct such tokens do not end a subtitle.
    tagrgx = re.compile(".*<.*[" + "".join(MARKERS["punct"]) + "].*>.*")
    
    found_tier = False
    with open(bpf) as handle:
            # tokens / link indices collected for the subtitle currently being built
            totallabel = []
            indices = set()

            for line in handle:
                # NOTE(review): prefix match, so a tier name that is a prefix of
                # another tier name would also match that tier's lines.
                if line.startswith(tier):
                    found_tier = True
                    splitline = line.strip().split()
                    
                    # skip lines without a label field
                    if len(splitline) < 3: continue

                    # NOTE(review): only the first whitespace-separated word of the
                    # label is kept; any further words on the line are dropped.
                    label = splitline[2]
                    totallabel.append(label)

                    idx = int(splitline[1])

                    indices.add(idx)

                    # Statistics for the warnings below. Both are counted before the
                    # split decision, so a marker inside a tag still counts, and
                    # over_long also counts splits that a marker triggered anyway.
                    found_markers += any([indicator in label for indicator in markers])
                    over_long += len(totallabel) >= maxlength

                    if (any([indicator in label for indicator in markers]) and not (marker == "punct" and tagrgx.match(label))) \
                            or len(totallabel) >= maxlength:

                        sub = make_subtitle(indices, totallabel, maudict, marker, MARKERS["tag"], outformat, tier = tier)
                        if not sub["indices"] is None: 
                            subtitles.append(sub)

                        # Unaligned text is glued onto the previous subtitle.
                        # NOTE(review): no separator is inserted, and the text is
                        # dropped if there is no previous subtitle yet.
                        elif len(subtitles):
                            subtitles[-1]["label"] += sub["label"]
                        
                        totallabel = []
                        indices = set()
            
            # flush the tokens remaining after the last split
            if len(indices) > 0 and len(totallabel) > 0:

                sub = make_subtitle(indices, totallabel, maudict, marker, MARKERS["tag"], outformat, tier = tier)
                if not sub["indices"] is None: 
                    subtitles.append(sub)
                
                elif len(subtitles):
                    subtitles[-1]["label"] += sub["label"]

    # drop subtitles whose label became empty (e.g. only a <BREAK> tag)
    subtitles = [x for x in subtitles if len(x["label"])]

    if not found_tier:
        raise Exception("Did not find " + tier + " tier in " + bpf)

    if found_markers == 0:
        warning = "Did not find a single substring separation marker (--marker {}). You may want to check your input.".format(marker)
        warn(warning)
        
    if over_long != 0:
        warning = "{} subtitles were split by reaching maximum length {}.".format(over_long, maxlength)
        warn(warning)

    output(outfile, outformat, subtitles, samplerate, bpf, origbpf)
