#!/usr/bin/env python3

# TODO:
# use re.compile for sep1 and sep2
# check all instances of 'pass' and 'oh no'
# (done) ExbTier.events is now a list of dictionaries rather than of tuples
# unify attribute and method names for ParFile <-> ExbFile as well as ParTier <-> ExbTier
# format with black
# maybe use filename conversion after all instead of printing to STDOUT when args.output is not provided
# change par_to_exb so that it accounts for tier boundary distances of size args.leeway even across tiers (not only between consecutive items of the same tier)
# or alternatively, if that's too difficult: always subtract one from start_time during the tier loop


import argparse
from os import path
import sys
import re
from xml.etree.ElementTree import *
from xml.dom import minidom


class ParFile:
	"""Parsed BAS Partitur file (.par / .bpf): a header dict plus a list of ParTier objects."""
	
	# separator between intra-tier values, such as "108885	2056	1	d_s"
	_VALUE_SEP = re.compile(r"[ \t]+")
	
	def __init__(self, inputstr = None):
		self.header = {} # header label -> value string, e.g. "SAM" -> "16000"
		self.tiers = [] # ParTier objects, in order of first appearance in the input
		if inputstr:
			self.parse(inputstr)
	
	@staticmethod
	def _split_label(line):
		"""Split a "LABEL: value" line into ("LABEL", "value"), both stripped.
		
		Raises ValueError if the line contains no colon at all.
		"""
		label, colon, value = line.partition(":")
		if not colon:
			raise ValueError("malformed BPF line (expected 'LABEL: value'): %r" % line)
		return label.strip(), value.strip()
	
	def parse(self, inputstr):
		"""Parse the text of a BPF file into self.header and self.tiers.
		
		Lines before the "LBD:" line are header entries ("REP: Muenchen"); every
		line after it is a tier item ("MAU: 108885 2056 1 d_s"). Tier items are
		grouped by tier name, and each group is parsed by its own ParTier.
		Blank lines and surrounding whitespace are ignored.
		"""
		# a single iterator, so the tier loop resumes right after the LBD line
		lines = iter(inputstr.splitlines())
		
		# header: everything up to the "LBD:" terminator
		for line in lines:
			line = line.strip()
			if not line:
				continue
			if line == "LBD:":
				break
			label, value = self._split_label(line)
			self.header[label] = value
		
		# tiers: group the value lists of each line by tier name
		raw_tiers = {}
		for line in lines:
			line = line.strip()
			if not line:
				continue
			tier_name, values = self._split_label(line)
			raw_tiers.setdefault(tier_name, []).append(self._VALUE_SEP.split(values))
		
		for tier_name, items in raw_tiers.items():
			self.tiers.append(ParTier(tier_name).parse(items))
	
	@classmethod
	def fromFilepath(cls, filepath, encoding = "utf-8"):
		"""Read and parse the BPF file at *filepath*, decoded with *encoding* (UTF-8 by default)."""
		with open(filepath, encoding = encoding) as file:
			return cls(file.read())
	
	def get_header_element(self, label):
		"""Return the header value stored under *label* (KeyError if absent)."""
		return self.header[label]
	
	def set_header_element(self, label, value):
		"""Set the header value stored under *label*."""
		self.header[label] = value
		# this should be modified so that it inserts the element in the correct position as well
	
	def in_header(self, label):
		"""Return True if the header contains *label*."""
		return label in self.header
	
	def samples_to_seconds(self, sample):
		"""Convert a sample index into seconds using the SAM (sample rate) header value."""
		return sample / int(self.get_header_element("SAM"))
	
	def remove_tier(self, name):
		"""Remove every tier called *name*."""
		# slice assignment rebuilds the list in place; removing while iterating would skip elements
		self.tiers[:] = [tier for tier in self.tiers if tier.name != name]
	
	def get_link_times(self, link):
		"""Return (start, end), in samples, of symbolic link *link*, inferred from class 4 tiers.
		
		Only class 4 items whose link is a single integer are considered; the result
		runs from the earliest start to the latest start + duration among the items
		referencing *link*.
		
		Raises Exception("no_class_4_tiers") when there is no class 4 tier, and
		Exception("cannot_infer_link_time") when no such item references *link*;
		par_to_exb() distinguishes the two by their message strings.
		"""
		tiers_class4 = [tier for tier in self.tiers if tier.tier_class == 4]
		if not tiers_class4:
			raise Exception("no_class_4_tiers")
		start_times = []
		end_times = []
		for tier in tiers_class4:
			for item in tier.tier_items:
				# items with other linktypes could also be used to infer link times
				if item["linktype"] == "single" and item["link"] == link:
					start_times.append(item["start"])
					end_times.append(item["start"] + item["duration"])
		if not start_times:
			raise Exception("cannot_infer_link_time")
		return (min(start_times), max(end_times))


class ParTier:
	"""One tier of a BPF file (e.g. ORT, MAU) together with its parsed items."""
	
	# BPF tier name -> tier class. parse() reads the leading fields by class:
	# classes 2 and 4 start + duration, classes 3 and 5 a time point,
	# classes 1, 4 and 5 a symbolic link; the remaining fields are the content.
	tier_classes = {
		"KAN": 1, "KSS": 1, "MRP": 1, "KAS": 1, "PTR": 1, "ORT": 1, "TRL": 1, "TR2": 1,
		"TRO": 1, "SUP": 1, "DAS": 1, "PRS": 1, "NOI": 1, "PRO": 1, "SYN": 1, "FUN": 1,
		"LEX": 1, "POS": 1, "LMA": 1, "TRS": 1, "TLN": 1, "TRW": 1, "SPK": 1,
		"IPA": 2, "GES": 2, "USH": 2, "USM": 2, "OCC": 2, "SPD": 2,
		"LBP": 3, "LBG": 3, "PRM": 3,
		"PHO": 4, "SAP": 4, "MAU": 4, "WOR": 4, "TRN": 4, "USP": 4, "MAS": 4,
		"PRB": 5,
	}
	
	def __init__(self, name):
		self.name = name
		self.tier_items = [] # one dict per tier line; keys depend on tier_class
	
	@staticmethod
	def _parse_link(link):
		"""Return (link, linktype) for a link field such as "3", "2,3" or "2;3"."""
		if "," in link:
			return [int(num) for num in link.split(",")], "multiple"
		if ";" in link:
			return [int(num) for num in link.split(";")], "between"
		return int(link), "single"
	
	def parse(self, content):
		"""Parse *content*, a list of per-line field lists, into self.tier_items and return self.
		
		Each item becomes a dict with "start"/"duration" (classes 2, 4), "time"
		(classes 3, 5), "link"/"linktype" (classes 1, 4, 5) and "content" (the
		remaining fields joined by single spaces). The input lists are not modified.
		"""
		tier_class = self.tier_class
		for item in content:
			fields = list(item) # work on a copy so the caller's lists stay intact
			entry = {}
			if tier_class in (2, 4):
				entry["start"] = int(fields.pop(0))
				entry["duration"] = int(fields.pop(0))
			if tier_class in (3, 5):
				entry["time"] = int(fields.pop(0))
			if tier_class in (1, 4, 5):
				entry["link"], entry["linktype"] = self._parse_link(fields.pop(0))
			entry["content"] = " ".join(fields)
			# appended only once fully parsed, so a bad line leaves no partial item behind
			self.tier_items.append(entry)
		return self
	
	@property
	def tier_class(self):
		"""BPF class (1-5) of this tier; raises KeyError for unknown tier names."""
		try:
			return self.tier_classes[self.name]
		except KeyError:
			raise KeyError("unknown BPF tier name: %r" % self.name) from None


class ExbFile:
	"""EXMARaLDA basic transcription (.exb) assembled from ExbTier objects.
	
	The document skeleton comes from a fixed XML template. Header fields are filled
	in with set_header_element / set_header_attribute, tiers are registered with
	add_tier, and generate() serializes everything to a pretty-printed XML string.
	"""
	
	def __init__(self):
		self.timeline = [] # not read by this class; kept for backward compatibility
		self.tiers = [] # ExbTier objects, in output order
		self.times = [] # sorted distinct event times, filled by construct_common_timeline()
		self.template = """
<basic-transcription>
<head>
<meta-information>
<project-name></project-name>
<transcription-name></transcription-name>
<referenced-file url=""/>
<ud-meta-information></ud-meta-information>
<comment></comment>
<transcription-convention></transcription-convention>
</meta-information>
<speakertable>
<speaker id="">
<abbreviation></abbreviation>
<sex value=""/>
<languages-used/>
<l1/>
<l2/>
<ud-speaker-information></ud-speaker-information>
<comment></comment>
</speaker>
</speakertable>
</head>
<basic-body>
<common-timeline></common-timeline>
</basic-body>
</basic-transcription>
		"""
		# ^ this really ought to be prettier
		self.tree = ElementTree(fromstring(self.template))
	
	def set_header_element(self, element, content):
		"""Set the text of the first <element> in the document to *content*."""
		next(self.tree.iter(element)).text = content
	
	def set_header_attribute(self, element, attribute, value):
		"""Set *attribute* to *value* on the first <element> in the document."""
		next(self.tree.iter(element)).set(attribute, value)
	
	def add_tier(self, *args, **kwargs):
		"""Create an ExbTier from the given arguments, register it and return it."""
		new_tier = ExbTier(*args, **kwargs)
		self.tiers.append(new_tier)
		return new_tier
	
	def remove_tier(self, tier):
		"""Unregister a tier previously returned by add_tier."""
		self.tiers.remove(tier)
	
	def construct_common_timeline(self):
		"""Build <common-timeline> from the start and end times of all events.
		
		Every distinct time becomes one <tli id="T<n>" time="..."> in ascending order,
		and each event's "start_ID"/"end_ID" is set to the id of its time point.
		Timeline items left over from a previous call are discarded first.
		"""
		common_timeline = self.tree.find(".//common-timeline")
		for tli in list(common_timeline):
			common_timeline.remove(tli)
		
		self.times = sorted({
			time
			for tier in self.tiers
			for event in tier.events
			for time in (event["start_time"], event["end_time"])
		})
		# one dict lookup per event instead of rescanning all events for every time point
		time_ids = {time: "T%d" % i for i, time in enumerate(self.times)}
		
		for tier in self.tiers:
			for event in tier.events:
				event["start_ID"] = time_ids[event["start_time"]]
				event["end_ID"] = time_ids[event["end_time"]]
		
		for time, ID in time_ids.items():
			common_timeline.append(Element("tli", {
				"id": ID,
				"time": str(time)
			}))
	
	def construct_tiers(self):
		"""Append one <tier> element, with its <event> children, per ExbTier to <basic-body>.
		
		Relies on construct_common_timeline() having set each event's start_ID/end_ID.
		Tier elements left over from a previous call are discarded first.
		"""
		basic_body = self.tree.find(".//basic-body")
		for old_tier in basic_body.findall("tier"):
			basic_body.remove(old_tier)
		for i, tier in enumerate(self.tiers):
			attributes = {
				"id": "TIE%d" % i,
				"speaker": tier.speaker,
				"category": tier.category,
				"display-name": tier.display_name,
				"type": tier.tier_type
			}
			# ElementTree cannot serialize None, so unset attributes (e.g. speaker) are omitted
			tier_element = Element("tier", {name: value for name, value in attributes.items() if value is not None})
			basic_body.append(tier_element)
			for event in tier.events:
				event_element = Element("event", {
					"start": event["start_ID"],
					"end": event["end_ID"],
				})
				event_element.text = event["content"]
				tier_element.append(event_element)
	
	def generate(self):
		"""Build the timeline and tiers and return the whole document as pretty-printed XML."""
		self.construct_common_timeline()
		self.construct_tiers()
		outputstr = minidom.parseString(tostring(self.tree.getroot())).toprettyxml()
		# collapse the empty lines produced from the template's whitespace-only text nodes
		outputstr = re.sub("\n+", "\n", outputstr)
		return outputstr


class ExbTier:
	"""A single EXMARaLDA tier: display metadata plus an ordered list of events.
	
	Each event is a dict with the keys "start_ID", "end_ID" (both None until
	ExbFile assigns timeline ids), "start_time", "end_time" and "content".
	"""
	
	def __init__(self, category, tier_type, display_name = "", speaker = None):
		self.category, self.tier_type = category, tier_type
		self.display_name, self.speaker = display_name, speaker
		self.events = []
	
	def add_event(self, start_time, end_time, content):
		"""Append an event spanning start_time..end_time with the given text content."""
		new_event = dict(
			start_ID = None,
			end_ID = None,
			start_time = start_time,
			end_time = end_time,
			content = content,
		)
		self.events.append(new_event)


def _copy_header(par, exb, referenced_files):
	"""Copy the BPF header fields with an EXMARaLDA equivalent from *par* to *exb*; return the speaker id."""
	# see header-equivalents.txt
	if par.in_header("DBN"):
		exb.set_header_element("project-name", par.get_header_element("DBN"))
	
	if referenced_files:
		exb.set_header_attribute("referenced-file", "url", referenced_files[0])
		# this should be changed so that it actually creates multiple
		# "referenced-file" elements for each element of the list passed as a parameter
	elif par.in_header("SRC"):
		exb.set_header_attribute("referenced-file", "url", par.get_header_element("SRC"))
	
	if par.in_header("CMT"):
		exb.set_header_element("comment", par.get_header_element("CMT"))
	
	# SPN is compulsory in BPF; fall back to EXMARaLDA-style "SPK0" instead of crashing when it is missing
	speaker_id = par.get_header_element("SPN") if par.in_header("SPN") else "SPK0"
	exb.set_header_attribute("speaker", "id", speaker_id)
	
	exb.set_header_attribute("sex", "value", "u")
		# 'unknown' or 'undefined' since this information is not provided in a .par file
	return speaker_id


def par_to_exb(par, exb, referenced_files = None, leeway = None):
	"""Fill the ExbFile *exb* with the header fields and tiers of the ParFile *par*.
	
	referenced_files: optional list of audio file names; only the first one is used,
		otherwise the SRC header value (if present) becomes the referenced file.
	leeway: distance in samples within which an item's start is snapped to the end
		of the previous item of the same tier; None (the default) means 0.
	
	Class 3 tiers are skipped. Class 1 and 5 tiers get their times from the class 4
	tiers via par.get_link_times(); a tier whose link times cannot be inferred, or
	that uses semicolon ("between") links, is skipped with a warning.
	"""
	speaker_id = _copy_header(par, exb, referenced_files)
	if leeway is None:
		leeway = 0 # no snapping of nearby boundaries
	
	for par_tier in par.tiers:
		if par_tier.tier_class == 3:
			# discard the current tier and continue with the next one
			echo("class 3 tiers are currently not supported. Skipping tier '{}'".format(par_tier.name), type = "warning")
			continue
		exb_tier = exb.add_tier(category = "0", tier_type = "0", display_name = par_tier.name, speaker = speaker_id)
		previous_end_time = 0
		for i, tier_item in enumerate(par_tier.tier_items):
			if par_tier.tier_class in (1, 5):
				try:
					if tier_item["linktype"] == "single":
						start, end = par.get_link_times(tier_item["link"])
					elif tier_item["linktype"] == "multiple":
						start, _ = par.get_link_times(min(tier_item["link"]))
						_, end = par.get_link_times(max(tier_item["link"]))
					elif tier_item["linktype"] == "between":
						# discard the current tier and continue with the next one
						echo("cannot currently handle semicolon-separated symbolic links (such as '2;3'). Skipping tier '{}'".format(par_tier.name), type = "warning")
						exb.remove_tier(exb_tier)
						break
				except Exception as exception:
					# get_link_times signals its failure modes through the exception message
					if str(exception) == "cannot_infer_link_time":
						# discard the current tier and continue with the next one
						echo("could not infer time of symbolic link. Skipping tier {tier}\n\tIf the input file contains any class 1 or class 5 tiers, it must contain at least one class 4 tier whose items all reference precisely one symbolic link (no comma or semicolon-separated list of links), otherwise the time of the symbolic links cannot be inferred. In this case, the time of the symbolic link {link} of the item of index {index} of tier {tier} could not be inferred".format(link = tier_item["link"], index = i, tier = par_tier.name), type = "warning")
						exb.remove_tier(exb_tier)
						break
					elif str(exception) == "no_class_4_tiers":
						# discard the current tier and continue with the next one
						echo("no class 4 tiers could be found. Skipping tier '{}'\n\tIf the input file contains any class 1 or class 5 tiers, it must contain at least one class 4 tier whose items all reference precisely one symbolic link (no comma or semicolon-separated list of links), otherwise the time of the symbolic links cannot be inferred".format(par_tier.name), type = "warning")
						exb.remove_tier(exb_tier)
						break
					else:
						raise # bare raise keeps the original traceback
			elif par_tier.tier_class in (2, 4):
				start = tier_item["start"]
				end = start + tier_item["duration"]
			# snap a start that lies within leeway samples of the previous end onto it,
			# so that both share one timeline item
			if previous_end_time - leeway <= start <= previous_end_time + leeway:
				start = previous_end_time
			# TODO: alternatively subtract one sample from start (if start > 0)
			previous_end_time = end
			start, end = par.samples_to_seconds(start), par.samples_to_seconds(end)
			exb_tier.add_event(start, end, tier_item["content"])


def echo(*messages, type, required_verbosity = None, verbosity = None):
	"""Emit a diagnostic message prefixed with "mausbpf2exb:".
	
	type:
		"debug"          print *messages* to stdout if verbosity == required_verbosity
		"debug_multiple" print the message whose index is the verbosity (clamped to the
		                 available messages) to stdout
		"warning"        print *messages* to stderr
		"error"          exit the program via sys.exit with the joined messages
	verbosity: defaults to the global command-line ``args.verbosity`` when running as
		a script, or 0 when no such global exists (e.g. when imported as a module).
	
	Raises ValueError for an unknown *type*.
	"""
	if verbosity is None:
		cli_args = globals().get("args")
		verbosity = getattr(cli_args, "verbosity", 0)
	if type == "debug":
		if verbosity == required_verbosity:
			print("DEBUG: mausbpf2exb:", *messages)
	elif type == "debug_multiple":
		if messages:
			index = max(0, min(verbosity, len(messages) - 1))
			print("DEBUG: mausbpf2exb:", messages[index])
	elif type == "warning":
		print("WARNING: mausbpf2exb:", *messages, file = sys.stderr)
	elif type == "error":
		# join like print() does, so multi-part messages are separated by spaces
		sys.exit("ERROR: mausbpf2exb: " + " ".join(str(message) for message in messages) + " - exiting...")
	else:
		raise ValueError("unknown message type: %r" % type)


if __name__ == "__main__":
	# command-line entry point: read BPF from a file or STDIN, write .exb to a file or STDOUT.
	# `args` stays a module-level global because echo() reads args.verbosity.
	
	parser = argparse.ArgumentParser()
	parser.add_argument("input", nargs = "?", help = "input file (.par, .bpf). If not specified, STDIN will be used")
	parser.add_argument("-o", "--output", help = "output file (.exb). If not specified, STDOUT will be used")
	parser.add_argument("-v", "--verbosity", type = int, default = 0, help = "verbosity level on a scale of 0 to 2 (default = 0)")
	parser.add_argument("-r", "--referenced-file", help = "name (or comma-separated list of names) of referenced audio file(s)")
	parser.add_argument("--ignore", dest = "ignored_tier", help = "name (or comma-separated list of names) of BPF tier(s) to be ignored during conversion")
	parser.add_argument("--leeway", type = int, default = 5, help = "minimum distance (in number of samples) which two timeline items in the .exb file have to be apart in order to be regarded as distinct items (default = 5). If below the threshold, the items will be merged.")
	parser.add_argument("--version", action = "store_true", help = "print version number and exit")
	args = parser.parse_args()
	
	if args.version:
		print("0.1")
		sys.exit()
	
	parFile = ParFile.fromFilepath(args.input) if args.input else ParFile(sys.stdin.read())
	exbFile = ExbFile()
	if args.ignored_tier:
		# tolerate "ORT, KAN" as well as "ORT,KAN"
		for ignored_tier in args.ignored_tier.split(","):
			parFile.remove_tier(ignored_tier.strip())
	if args.referenced_file:
		args.referenced_file = [name.strip() for name in args.referenced_file.split(",")]
	par_to_exb(parFile, exbFile, args.referenced_file, args.leeway)
	outputstr = exbFile.generate()
	if args.output:
		with open(args.output, mode = "w", encoding = "utf-8", newline = "\n") as outputfile:
			outputfile.write(outputstr)
	else:
		# the generated XML already ends with a newline; print() would append a second one
		sys.stdout.write(outputstr)

