#!/usr/bin/env python3

###### mausbpf2exb
#	last edit: 2019-11-19
#	author: David Huss
#	email: david.huss@phonetik.uni-muenchen.de
#	For more information, consult README.md or run
#		`python3 mausbpf2exb.py --help`
######


# TODO:
# use re.compile for sep1 and sep2
# ExbTier.events should probably be a list of dictionaries rather than of tuples
# unify attribute and method names for ParFile <-> ExbFile as well as ParTier <-> ExbTier
# add warning if user chooses to ignore a tier but the tier isn't present anyway
# add error if there's an unknown BPF tier
# format with black
# add missing docstrings
# fix docstring for ParTier.parse()
# document public attributes of classes as well
# maybe use filename conversion after all instead of printing to STDOUT when args.output is not provided
# change par_to_exb so that it accounts for tier boundary distances of size args.leeway even across different class 4 tiers
# or alternatively, if that's too difficult: always subtract one from start_time during the tier loop


import argparse
from os import path
import sys
import re
from collections import OrderedDict
from xml.etree.ElementTree import *
from xml.dom import minidom


class ParFile:
	"""
	A class which represents/encapsulates files in the BAS Partitur Format.
	
	Arguments:
		inputstr (str, optional): a string with the contents of a BPF file.
		If this argument is provided, the method parse() will automatically be called.
		An alternative way to instantiate the class is to use the class method fromFilepath() -
		this method expects a filename and will automatically retrieve the contents of the file and parse them.
	
	Attributes:
		header (dict): Maps header labels (e.g. "SAM") to their values (str).
		tiers (list of ParTier): The tiers of the file, in order of their first appearance.
	"""
	# separator between labels and their values, such as "REP: Muenchen" (or "CMT:" with an empty value)
	_LABEL_SEPARATOR = re.compile(r":[ \t]*")
	# separator between intra-tier values, such as "108885	2056	1	d_s"
	_VALUE_SEPARATOR = re.compile(r"[ \t]+")
	
	def __init__(self, inputstr = None):
		"""See class docstring."""
		self.header = {}
		self.raw_tiers = OrderedDict()
		self.tiers = []
		if inputstr:
			self.parse(inputstr)
	
	@classmethod
	def _split_label(cls, line):
		"""
		Split one BPF line into its label and the rest of the line.
		
		Arguments:
			line (str): A non-blank line of a BPF file.
		
		Returns (tuple of str): (label, value), with surrounding spaces/tabs removed.
		
		Raises:
			ValueError: if the line does not contain a colon-terminated label.
		"""
		parts = cls._LABEL_SEPARATOR.split(line.strip(" \t"), maxsplit = 1)
		if len(parts) != 2:
			raise ValueError("malformed BPF line (expected 'LABEL: value'): {!r}".format(line))
		return parts[0], parts[1]
	
	def parse(self, inputstr):
		"""
		Take the contents of a BPF file, parse it, and store it in self.header and self.tiers.
		
		Blank lines are ignored. The header ends at the "LBD:" line (surrounding whitespace
		is tolerated); every line after it is an item of the tier named by its label.
		
		Arguments:
			inputstr (str): The contents of the BPF file.
		
		Raises:
			ValueError: if a non-blank line has no "LABEL:" prefix.
		"""
		# a fresh buffer, so that parse() can also be called more than once
		self.raw_tiers = OrderedDict()
		# one iterator shared by both loops: the tier loop resumes
		# right after the line at which the header loop stopped
		lines = iter(inputstr.splitlines())
		
		# extract and parse header
		for line in lines:
			if not line.strip():
				continue
			if line.strip() == "LBD:":
				# the LBD line signifies the end of the header
				break
			label, value = self._split_label(line)
			self.header[label] = value
		
		# extract and parse tiers
		for line in lines:
			if not line.strip():
				continue
			tier_name, values = self._split_label(line)
			self.raw_tiers.setdefault(tier_name, []).append(self._VALUE_SEPARATOR.split(values))
		
		for tier_name, raw_items in self.raw_tiers.items():
			self.tiers.append(ParTier(tier_name).parse(raw_items))
		
		# the raw data now lives in the ParTier objects
		del self.raw_tiers
	
	@classmethod
	def fromFilepath(cls, filepath):
		"""
		Open a BPF file and pass its contents to the class constructor, returning an instance of the class.
		The file is decoded as UTF-8; a leading byte order mark is tolerated.
		
		Arguments:
			filepath (str): The name of the BPF file.
		"""
		with open(filepath, encoding = "utf-8-sig") as file:
			return cls(file.read())
	
	def get_header_element(self, label):
		"""
		Return an element from the BPF header.
		
		Arguments:
			label (str): The name of the header element.
		
		Raises:
			KeyError: if the element is not present in the header.
		"""
		return self.header[label]
	
	def set_header_element(self, label, value):
		"""
		Set (or add) an element of the BPF header.
		
		Arguments:
			label (str): The name of the header element.
			value (str): The value of the header element.
		"""
		self.header[label] = value
		# this should be modified so that it inserts the element in the correct position as well
	
	def in_header(self, label):
		"""
		Return a boolean indicating whether the specified element is present in the BPF header.
		
		Arguments:
			label (str): The name of the header element.
		"""
		return label in self.header
	
	def samples_to_seconds(self, sample):
		"""
		Convert a number of samples to seconds in accordance with the BPF file's sampling rate (header element SAM).
		
		Arguments:
			sample (int): Number of samples.
		
		Returns (float): Number of seconds.
		"""
		return sample / int(self.get_header_element("SAM"))
	
	def remove_tier(self, name):
		"""
		Remove a BPF tier belonging to the BPF file object. Does nothing if no tier has that name.
		
		Arguments:
			name (str): The name of the tier.
		"""
		# rebuild the list instead of removing while iterating over it;
		# slice assignment keeps the identity of self.tiers
		self.tiers[:] = [tier for tier in self.tiers if tier.name != name]
	
	def get_tier(self, name):
		"""
		Return a BPF tier object (which belongs to the BPF file object).
		
		Arguments:
			name (str): The name of the tier.
		
		Raises:
			Exception("tier_not_found"): if no tier has that name.
		"""
		for tier in self.tiers:
			if tier.name == name:
				return tier
		# exceptions with underscore syntax are for 'internal use'
		# (issued in classes, then caught in the `if __name__ == "__main__":` section,
		# where they cause errors to be printed to stderr)
		raise Exception("tier_not_found")
	
	@staticmethod
	def get_reference_tiers():
		"""
		Return the 'hierarchy' of reference tiers used for symbolic link inference.
		For more information have a look at the method get_link_times().
		"""
		return ["MAU", "SAP", "WOR", "PHO", "MAS"]
	
	def get_link_times(self, link):
		"""
		Infer the start and end time (in samples) of a symbolic link from the reference tiers.
		
		The reference tiers are tried in the order given by get_reference_tiers(), and only the
		first one present in the file is used. Among its items whose single link equals `link`,
		the earliest start and the latest end (start + duration) are returned.
		
		Arguments:
			link (int): The symbolic link.
		
		Returns (tuple): (start, end) in samples.
		
		Raises:
			Exception("cannot_infer_link_time"): if the first present reference tier has no single-link item for `link`.
			Exception("no_valid_reference_tiers"): if none of the reference tiers is present.
		"""
		for reference_tier_name in ParFile.get_reference_tiers():
			try:
				reference_tier = self.get_tier(reference_tier_name)
			except Exception as exception:
				if str(exception) == "tier_not_found":
					continue
				raise
			matching_items = [
				item for item in reference_tier.tier_items
				if item["linktype"] == "single" and item["link"] == link
			]
			if matching_items:
				return (
					min(item["start"] for item in matching_items),
					max(item["start"] + item["duration"] for item in matching_items),
				)
			# exceptions with underscore syntax are for 'internal use'
			# (issued in classes, then caught in the `if __name__ == "__main__":` section,
			# where they cause errors to be printed to stderr)
			raise Exception("cannot_infer_link_time")
		raise Exception("no_valid_reference_tiers")


class ParTier:
	"""
	A class which represents/encapsulates tiers of a BPF file.
	
	Arguments:
		name (str): The name of the tier (e.g. "ORT" or "MAU").
	
	Attributes:
		name (str): The name of the tier.
		tier_items (list of dict): The items of the tier, filled by parse().
		tier_classes (dict): Maps known BPF tier names to their tier class (1-5).
	"""
	# known BPF tier names, grouped by tier class
	TIER_CLASSES = {
		"KAN": 1, "KSS": 1, "MRP": 1, "KAS": 1, "PTR": 1, "ORT": 1, "TRL": 1, "TR2": 1,
		"TRO": 1, "SUP": 1, "DAS": 1, "PRS": 1, "NOI": 1, "PRO": 1, "SYN": 1, "FUN": 1,
		"LEX": 1, "POS": 1, "LMA": 1, "TRS": 1, "TLN": 1, "TRW": 1, "SPK": 1,
		"IPA": 2, "GES": 2, "USH": 2, "USM": 2, "OCC": 2, "SPD": 2,
		"LBP": 3, "LBG": 3, "PRM": 3,
		"PHO": 4, "SAP": 4, "MAU": 4, "WOR": 4, "TRN": 4, "USP": 4, "MAS": 4,
		"PRB": 5,
	}
	
	def __init__(self, name):
		self.name = name
		self.tier_items = []
		# per-instance copy, so that modifying it cannot affect other tiers
		self.tier_classes = dict(ParTier.TIER_CLASSES)
	
	@staticmethod
	def _parse_link(link):
		"""
		Convert a raw symbolic link field into a (link, linktype) tuple.
		
		"2,3" -> ([2, 3], "multiple"); "2;3" -> ([2, 3], "between"); "2" -> (2, "single").
		"""
		if "," in link:
			return [int(num) for num in link.split(",")], "multiple"
		if ";" in link:
			return [int(num) for num in link.split(";")], "between"
		return int(link), "single"
	
	def parse(self, content):
		"""
		Parse the split lines of a BPF tier into dictionaries appended to self.tier_items,
		one dictionary per item of the tier.
		
		The keys of each dictionary depend on the tier class:
			class 1: link, linktype, content
			class 2: start, duration, content
			class 3: time, content
			class 4: start, duration, link, linktype, content
			class 5: time, link, linktype, content
		"start", "duration" and "time" are ints (samples). "link" is an int for a single link
		(linktype "single"), or a list of ints for comma-separated ("multiple") or
		semicolon-separated ("between") links. "content" is the remaining fields joined by spaces.
		
		Arguments:
			content (list of list of str): One list of whitespace-separated fields per line of the
				tier, without the tier label. The lists are not modified.
		
		Returns (ParTier): self, so that the call can be chained to the constructor.
		
		Raises:
			KeyError: if the tier name is not a known BPF tier.
			IndexError: if a line has fewer fields than its tier class requires.
			ValueError: if a time or link field is not an integer.
		"""
		for raw_item in content:
			tier_class = self.tier_class
			# work on a copy so that the caller's lists stay untouched
			fields = list(raw_item)
			tier_item = {}
			if tier_class in (2, 4):
				tier_item["start"] = int(fields.pop(0))
				tier_item["duration"] = int(fields.pop(0))
			if tier_class in (3, 5):
				tier_item["time"] = int(fields.pop(0))
			if tier_class in (1, 4, 5):
				tier_item["link"], tier_item["linktype"] = self._parse_link(fields.pop(0))
			tier_item["content"] = " ".join(fields)
			# only append fully parsed items
			self.tier_items.append(tier_item)
		return self
	
	@property
	def tier_class(self):
		"""
		Return the tier class (1-5).
		
		Raises:
			KeyError: if the tier name is not a known BPF tier.
		"""
		try:
			return self.tier_classes[self.name]
		except KeyError:
			raise KeyError("unknown BPF tier '{}'".format(self.name)) from None


class ExbFile:
	"""
	A class which represents/encapsulates files in the EXMARaLDA Partitur-Editor format (.exb).
	
	Attributes:
		timeline (list): Not used within this class.
		tiers (list): The tier objects of the file in output order (see add_tier()).
		times (list): Sorted, de-duplicated event boundary times, filled by construct_common_timeline().
		template (str): The XML skeleton the file is built from.
		tree (ElementTree): The XML tree which generate() serializes.
	"""
	def __init__(self):
		self.timeline = []
		self.tiers = []
		self.times = []
		self.template = """
<basic-transcription>
<head>
<meta-information>
<project-name></project-name>
<transcription-name></transcription-name>
<referenced-file url=""/>
<ud-meta-information></ud-meta-information>
<comment></comment>
<transcription-convention></transcription-convention>
</meta-information>
<speakertable>
<speaker id="">
<abbreviation></abbreviation>
<sex value=""/>
<languages-used/>
<l1/>
<l2/>
<ud-speaker-information></ud-speaker-information>
<comment></comment>
</speaker>
</speakertable>
</head>
<basic-body>
<common-timeline></common-timeline>
</basic-body>
</basic-transcription>
		"""
		self.tree = ElementTree(fromstring(self.template))
		# The template's line breaks survive parsing as whitespace-only .text/.tail values;
		# left in place, toprettyxml() would turn them into lines containing nothing but tabs.
		for element in self.tree.iter():
			if element.text is not None and not element.text.strip():
				element.text = None
			if element.tail is not None and not element.tail.strip():
				element.tail = None
	
	def set_header_element(self, element, content):
		"""
		Set the content of an element in <head> section of the .exb file object.
		
		Arguments:
			element (str): The name of the header element.
			content (str): The text to be inserted between the opening and closing tag of the header element.
		"""
		next(self.tree.iter(element)).text = content
	
	def set_header_attribute(self, element, attribute, value):
		"""
		Set an attribute of an element in the <head> section of the .exb file object.
		If the element does not exist yet, create it (as the last child of <head>).
		
		Arguments:
			element (str): The name of the header element.
			attribute (str): The name of the element's attribute to be modified.
			value (str): The value of the attribute.
		"""
		try:
			next(self.tree.iter(element)).set(attribute, value)
		except StopIteration:
			new_element = Element(element, {attribute: value})
			next(self.tree.iter("head")).append(new_element)
	
	def add_tier(self, *args, **kwargs):
		"""
		A sort of wrapper function which instantiates the ExbTier class and adds a reference to said instance the list self.tiers.
		Globally, the ExbTier class should generally not need to be instantiated through any means other than this function.
		
		Arguments:
			Same as for ExbTier's constructor (see that class's docstring).
		
		Returns (ExbTier): The new tier.
		"""
		new_tier = ExbTier(*args, **kwargs)
		self.tiers.append(new_tier)
		return new_tier
	
	def remove_tier(self, tier):
		"""
		Remove an ExbTier object belonging to the ExbFile object.
		
		Arguments:
			tier (ExbTier): A reference to the tier instance.
		"""
		self.tiers.remove(tier)
	
	def construct_common_timeline(self):
		"""
		Fill <common-timeline> with one <tli> element per distinct event boundary time, and
		store the matching timeline ID in every event's "start_ID" and "end_ID".
		
		Can be called repeatedly: <tli> elements from an earlier call are replaced.
		"""
		common_timeline = self.tree.find("basic-body/common-timeline")
		# drop <tli> elements left over from an earlier call
		for old_tli in list(common_timeline):
			common_timeline.remove(old_tli)
		times = set(self.times)
		for tier in self.tiers:
			for event in tier.events:
				times.add(event["start_time"])
				times.add(event["end_time"])
		self.times = sorted(times)
		# look IDs up by time instead of scanning every event once per time
		time_ids = {time: "T%d" % i for i, time in enumerate(self.times)}
		for time in self.times:
			common_timeline.append(Element("tli", {
				"id": time_ids[time],
				"time": str(time)
			}))
		for tier in self.tiers:
			for event in tier.events:
				event["start_ID"] = time_ids[event["start_time"]]
				event["end_ID"] = time_ids[event["end_time"]]
	
	def construct_tiers(self):
		"""
		Append one <tier> element per tier (with one <event> per event) to <basic-body>.
		Requires construct_common_timeline() to have assigned the events' timeline IDs.
		
		Can be called repeatedly: <tier> elements from an earlier call are replaced.
		A tier whose speaker is None gets no speaker attribute (None cannot be serialized).
		"""
		basic_body = self.tree.find("basic-body")
		for old_tier in basic_body.findall("tier"):
			basic_body.remove(old_tier)
		for i, tier in enumerate(self.tiers):
			attributes = {"id": "TIE%d" % i}
			if tier.speaker is not None:
				attributes["speaker"] = tier.speaker
			attributes["category"] = tier.category
			attributes["display-name"] = tier.display_name
			attributes["type"] = tier.tier_type
			tier_element = Element("tier", attributes)
			basic_body.append(tier_element)
			for event in tier.events:
				event_element = Element("event", {
					"start": event["start_ID"],
					"end": event["end_ID"],
				})
				event_element.text = event["content"]
				tier_element.append(event_element)
	
	def generate(self):
		"""
		Generate the XML output of the current state of the ExbFile object.
		
		Returns (str): The output string, ready for writing to file.
		"""
		self.construct_common_timeline()
		self.construct_tiers()
		outputstr = minidom.parseString(tostring(self.tree.getroot())).toprettyxml()
		outputstr = re.sub("\n+", "\n", outputstr)
		return outputstr


class ExbTier:
	"""
	Container for one tier of an .exb file and the events that belong to it.
	
	Arguments (each one becomes an XML attribute of the tier element):
		category (str): The category of the tier.
		tier_type (str): The type of the tier.
		display_name (str, optional): The name shown prominently in the EXMARaLDA GUI (default: "").
		speaker (str, optional): The ID of the speaker of the tier (default: None).
	"""
	def __init__(self, category, tier_type, display_name = "", speaker = None):
		self.category, self.tier_type = category, tier_type
		self.display_name, self.speaker = display_name, speaker
		# one dict per event, in insertion order (see add_event() for the keys)
		self.events = []
	
	def add_event(self, start_time, end_time, content):
		"""
		Append an event to the tier. Its timeline IDs start out as None.
		
		Arguments:
			start_time (int, float): When the event begins, in seconds.
			end_time (int, float): When the event ends, in seconds.
			content (str): The text placed inside the event element.
		"""
		new_event = dict(
			start_ID = None,
			end_ID = None,
			start_time = start_time,
			end_time = end_time,
			content = content,
		)
		self.events.append(new_event)


def par_to_exb(par, exb, referenced_file = None, leeway = None):
	"""
	Convert a ParFile object's attributes to an ExbFile object's attributes.
	
	Header elements are mapped as follows (see header-equivalents.txt):
		DBN -> <project-name>, SRC -> url of <referenced-file>, CMT -> <comment>,
		SPN -> id of <speaker>, SYS -> <transcription-convention>.
	Every BPF tier except class 3 tiers becomes one .exb tier. The times of class 1 and
	class 5 tier items are inferred from the reference tiers (see ParFile.get_link_times()).
	
	Arguments:
		par (ParFile): the input ParFile object.
		exb (ExbFile): the output ExbFile object (modified in place).
		referenced_file (str, optional): the name of the audio file which the BPF file annotates.
			Takes precedence over the SRC header element.
		leeway (int, optional): the maximum number of samples which two item boundaries
			may be apart in order to be merged (see documentation for --leeway parameter).
			None is treated as 0, i.e. no boundaries are merged.
	"""
	if leeway is None:
		leeway = 0
	
	echo("--- Starting header conversion ---\n", required_verbosity = 1)
	
	echo("Setting header attributes", required_verbosity = 1)
	# see header-equivalents.txt
	if par.in_header("DBN"):
		project_name = par.get_header_element("DBN")
		exb.set_header_element("project-name", project_name)
		echo("Converted header element: 'DBN' --> <project-name>", required_verbosity = 2)
	
	if referenced_file:
		exb.set_header_attribute("referenced-file", "url", referenced_file)
		echo("Inserted header element passed via command line parameter: -r --> <referenced-file>", required_verbosity = 2)
	elif par.in_header("SRC"):
		referenced_file = par.get_header_element("SRC")
		exb.set_header_attribute("referenced-file", "url", referenced_file)
		echo("Converted header element: 'SRC' --> <referenced-file>", required_verbosity = 2)
	
	if par.in_header("CMT"):
		comment = par.get_header_element("CMT")
		exb.set_header_element("comment", comment)
		echo("Converted header element: 'CMT' --> <comment>", required_verbosity = 2)
	
	if par.in_header("SPN"): # compulsory element anyway
		speaker_id = par.get_header_element("SPN")
		exb.set_header_attribute("speaker", "id", speaker_id)
		echo("Converted header element: 'SPN' --> id (attribute of <speaker>)", required_verbosity = 2)
	else:
		# every tier below references speaker_id, so fall back to a default
		# instead of crashing with a NameError
		speaker_id = "SPK0"
		exb.set_header_attribute("speaker", "id", speaker_id)
		echo("The BPF header contains no 'SPN' element - using the default speaker ID '{}'".format(speaker_id), type = "warning")
	
	if par.in_header("SYS"):
		transcription_convention = par.get_header_element("SYS")
		exb.set_header_element("transcription-convention", transcription_convention)
		echo("Converted header element: 'SYS' --> <transcription-convention>", required_verbosity = 2)
	
	exb.set_header_attribute("sex", "value", "u")
	# 'unknown' or 'undefined' since this information is not provided in a .par file
	
	if args.verbosity > 0:
		# stderr, so that the blank line cannot end up in .exb output written to STDOUT
		print(file = sys.stderr)
	echo("--- Finished header conversion ---\n", required_verbosity = 1)
	
	echo("--- Starting tier conversion ---\n", required_verbosity = 1)
	for par_tier in par.tiers:
		if par_tier.tier_class == 3:
			# discard the current tier and continue with the next one
			echo("Class 3 tiers are currently not supported. Skipping tier '{}'".format(par_tier.name), type = "warning")
			continue
		exb_tier = exb.add_tier(category = "0", tier_type = "0", display_name = par_tier.name, speaker = speaker_id)
		echo(
			"Starting conversion of tier '{}'".format(par_tier.name),
			"Starting conversion of tier '{}' (class {})".format(par_tier.name, par_tier.tier_class),
			type = "debug_multiple"
		)
		# class 1 and class 5 tiers have no (complete) start/duration of their own
		inference_necessary = par_tier.tier_class in (1, 5)
		if inference_necessary:
			echo(
				"The tier '{}' is a class {} tier - attempting symbolic link time inference".format(par_tier.name, par_tier.tier_class),
				"The tier '{}' is a class {} tier - it is necessary to infer the time of the symbolic links, which will be accomplished by making use of a selection of class 4 tiers. See --list-reference-tiers for more information".format(par_tier.name, par_tier.tier_class),
				type = "debug_multiple"
			)
		previous_end_time = 0
		for i, tier_item in enumerate(par_tier.tier_items):
			if inference_necessary:
				try:
					if tier_item["linktype"] == "single":
						start, end = par.get_link_times(tier_item["link"])
						echo("Performing symbolic link time inference of link '{}'".format(tier_item["link"]), required_verbosity = 2)
					elif tier_item["linktype"] == "multiple":
						start, _ = par.get_link_times(min(tier_item["link"]))
						_, end = par.get_link_times(max(tier_item["link"]))
						echo("Performing symbolic link time inference of link '{}'".format(tier_item["link"]), required_verbosity = 2)
					elif tier_item["linktype"] == "between":
						# discard the current tier and continue with the next one
						echo("Cannot handle semicolon-separated symbolic links (such as '2;3'). Skipping tier '{}'".format(par_tier.name), type = "warning")
						exb.remove_tier(exb_tier)
						break
				except Exception as exception:
					if str(exception) == "cannot_infer_link_time":
						# throw error message and exit
						echo("Could not infer time of symbolic link - exiting...\n\tIf the input file contains any class 1 or class 5 tiers, the program will cycle through a hierarchy of selected class 4 tiers (in the order MAU->SAP->WOR->PHO->MAS) and use the first one it finds to infer the times of the symbolic links. If there is not at least one of these tiers present which contains a reference to every symbolic link (and provided the reference is a single link, not a comma or semicolon-separated list of links), the file cannot be converted. In this case, the time of the symbolic link {link} of the item of index {index} of tier {tier} could not be inferred".format(link = tier_item["link"], index = i, tier = par_tier.name), type = "error")
					elif str(exception) == "no_valid_reference_tiers":
						# throw error message and exit
						echo("No valid reference tiers could be found - exiting...\n\tIf the input file contains any class 1 or class 5 tiers, the program will cycle through a hierarchy of selected class 4 tiers (in the order MAU->SAP->WOR->PHO->MAS) and use the first one it finds to infer the times of the symbolic links. If there is not at least one of these tiers present which contains a reference to every symbolic link (and provided the reference is a single link, not a comma or semicolon-separated list of links), the file cannot be converted", type = "error")
					else:
						raise
			elif par_tier.tier_class == 2 or par_tier.tier_class == 4:
				start = tier_item["start"]
				end = start + tier_item["duration"]
			if abs(start - previous_end_time) <= leeway:
				# boundaries at most `leeway` samples apart share one timeline item
				start = previous_end_time
			previous_end_time = end
			start, end = par.samples_to_seconds(start), par.samples_to_seconds(end)
			exb_tier.add_event(start, end, tier_item["content"])
			echo("Converted BPF tier item {} of index {} to .exb tier event".format(repr(tier_item["content"]), i), required_verbosity = 2)
		else:
			# only reached when the item loop was not aborted, i.e. the tier was kept
			echo("Added tier '{}'\n".format(par_tier.name), required_verbosity = 1)
	
	echo("--- Finished tier conversion ---\n", required_verbosity = 1)


def echo(*messages, type = "debug", required_verbosity = None):
	"""
	Display debug information, a warning, or an error to the user on the console.
	
	Everything is written to stderr, so that messages never get mixed into
	.exb output that is written to STDOUT.
	
	Arguments:
		messages (arbitrary number of str): The strings to be displayed. For all values of type except 'debug_multiple',
			the strings will be joined with spaces. If the value of type is 'debug_multiple', then one message
			should be provided for each verbosity level starting at 1, in ascending order. Nothing is displayed
			at verbosity 0, and verbosity levels above the number of messages use the last message. Example usage:
			`echo('message for verbosity = 1', 'message for verbosity = 2', type = 'debug_multiple')`
		type (str): either 'debug', 'debug_multiple', 'warning', or 'error'.
		required_verbosity (int, optional): if the value of type is 'debug', then this is the minimum required verbosity
			necessary for the messages to be displayed (default: 1).
	
	Raises:
		SystemExit: if type is 'error'; the joined messages become the exit message (exit status 1).
		ValueError: if type is not one of the values listed above.
	"""
	if type == "debug":
		if required_verbosity is None:
			required_verbosity = 1
		if args.verbosity >= required_verbosity:
			print("DEBUG: mausbpf2exb:", *messages, file = sys.stderr)
	elif type == "debug_multiple":
		if args.verbosity > 0 and messages:
			# clamp, so that verbosity levels without their own message use the most detailed one
			index = min(args.verbosity, len(messages)) - 1
			print("DEBUG: mausbpf2exb:", messages[index], file = sys.stderr)
	elif type == "warning":
		print("WARNING: mausbpf2exb:", *messages, file = sys.stderr)
	elif type == "error":
		sys.exit("ERROR: mausbpf2exb: " + " ".join(messages))
	else:
		raise ValueError("echo(): unknown message type {!r}".format(type))


if __name__ == "__main__":
	
	###### this block implements the parser as well as the help page etc
	parser = argparse.ArgumentParser(description = "This program converts speech annotation files in the BAS Partitur Format (BPF) to files for the Partitur-Editor of the EXMARaLDA speech software suite.")
	parser.add_argument("input", nargs = "?", help = "input file (.par, .bpf). If not specified, STDIN will be used")
	parser.add_argument("-o", "--output", help = "output file (.exb). If not specified, STDOUT will be used")
	parser.add_argument("-v", "--verbosity", type = int, default = 0, help = "verbosity level on a scale of 0 to 2 (default = 0)")
	parser.add_argument("-r", "--referenced-file", help = "name of referenced audio file")
	parser.add_argument("--ignore", dest = "ignored_tier", help = "name (or comma-separated list of names) of BPF tier(s) to be ignored during conversion")
	parser.add_argument("--leeway", type = int, default = 5, help = "minimum distance (in number of samples) which two timeline items in the .exb file have to be apart in order to be regarded as distinct items (default = 5). If below the threshold, the items will be merged.")
	parser.add_argument("--version", action = "store_true", help = "print version number and exit")
	parser.add_argument("--list-reference-tiers", action = "store_true", help = "print the hierarchy of class 4 tiers used to infer times of symbolic links and exit")
	args = parser.parse_args()
	######
	
	###### this block is for when the user passes parameters that only ask for information about the program
	if args.version:
		# display version number
		print("0.2")
		sys.exit()
	if args.list_reference_tiers:
		# list reference tiers
		for tier in ParFile.get_reference_tiers():
			print(tier)
		sys.exit()
	######
	
	###### this block takes care of the actual conversion
	# if an input file name has been provided, use that, otherwise, use STDIN
	if args.input:
		# isfile() rather than exists(): a directory cannot be read as a BPF file
		if path.isfile(args.input):
			parFile = ParFile.fromFilepath(args.input)
		else:
			echo("The input file '{}' could not be found.".format(args.input), type = "error")
	else:
		parFile = ParFile(sys.stdin.read())
	exbFile = ExbFile()
	# remove tier(s) that the user decided to ignore
	if args.ignored_tier:
		for ignored_tier in args.ignored_tier.split(","):
			ignored_tier = ignored_tier.strip()
			if any(tier.name == ignored_tier for tier in parFile.tiers):
				parFile.remove_tier(ignored_tier)
			else:
				echo("The tier '{}' was passed to --ignore but is not present in the input file".format(ignored_tier), type = "warning")
	par_to_exb(parFile, exbFile, args.referenced_file, args.leeway)
	outputstr = exbFile.generate()
	# if an output file name has been provided, use that, otherwise, use STDOUT
	if args.output:
		# the output file is created (or overwritten); only its directory has to exist
		output_directory = path.dirname(args.output) or "."
		if path.isdir(output_directory):
			with open(args.output, mode = "w", encoding = "utf-8", newline = "\n") as outputfile:
				outputfile.write(outputstr)
		else:
			echo("The directory of the output file '{}' could not be found.".format(args.output), type = "error")
	else:
		# outputstr already ends with a newline
		sys.stdout.write(outputstr)
	######

