import argparse
import time
import logging
import grpc
import os
import urllib
import urllib.request
import base64
import json
from importlib.machinery import SourceFileLoader
import threading
from google.protobuf import text_format
from google.protobuf import json_format

from nuance.tts.v1.synthesizer_pb2 import *
from nuance.tts.v1.synthesizer_pb2_grpc import *

# Per-thread state. run_one_file() resets the synthesis counter and sets the
# flow-file name for the thread that runs a flow; the counter is initialised
# here for the main thread.
thread_context = threading.local()
thread_context.num_synthesis = 0

# OAuth 2.0 token cache shared by all threads; get_oauth2_token() reads and
# writes it while holding oauth_mutex.
oauth_mutex = threading.Lock()
oauth_token = None
# time.monotonic() deadline after which oauth_token is no longer valid.
oauth_token_expiry_seconds = 0
# A cached token is refreshed once fewer than this many seconds remain.
oauth_token_expiry_threshhold_seconds = 30

# NOTE(review): appears unused in this file (run() uses its own local).
num_iterations = 0

# Running totals used to report the average first-chunk latency.
total_synthesis = 0
total_first_chunk_latency = 0

# Parsed command-line arguments, populated by parse_args().
args = None

# Generates the .wav file header for a given set of parameters
def generate_wav_header(sampleRate, bitsPerSample, channels, datasize, formattype):
    """Build the 44-byte RIFF/WAVE header that precedes ``datasize`` bytes of audio.

    Args:
        sampleRate: samples per second.
        bitsPerSample: bits per sample of a single channel.
        channels: number of channels.
        datasize: size in bytes of the audio data following the header.
        formattype: WAVE format tag (callers in this file pass 1 for PCM,
            6 for A-law and 7 for mu-law).

    Returns:
        The header as ``bytes``.
    """
    def little_endian(value, width):
        return value.to_bytes(width, 'little')

    header_fields = [
        b"RIFF",
        little_endian(datasize + 36, 4),        # RIFF size: whole file minus these first 8 bytes
        b"WAVE",
        b"fmt ",
        little_endian(16, 4),                   # length of the fmt sub-chunk body
        little_endian(formattype, 2),
        little_endian(channels, 2),
        little_endian(sampleRate, 4),
        little_endian(sampleRate * channels * bitsPerSample // 8, 4),  # byte rate
        little_endian(channels * bitsPerSample // 8, 2),               # block align
        little_endian(bitsPerSample, 2),
        b"data",
        little_endian(datasize, 4),             # size of the audio payload
    ]
    return b"".join(header_fields)

def send_http_request_with_json_response(url, headers, data, method):
    """Send an HTTP request and return the response body parsed as JSON.

    Args:
        url: target URL.
        headers: dict of request headers.
        data: request body as bytes, or None for no body.
        method: HTTP method name such as "GET" or "POST".

    Returns:
        The value decoded from the UTF-8 JSON body.

    Exceptions raised by ``urlopen`` or ``json.loads`` propagate unchanged.
    """
    http_request = urllib.request.Request(url=url, headers=headers, data=data, method=method)
    with urllib.request.urlopen(http_request) as http_response:
        body_text = http_response.read().decode('utf-8')
    return json.loads(body_text)

def get_oauth2_token():
    """Return a cached OAuth 2.0 access token, fetching a new one when needed.

    The cached token is reused while more than
    ``oauth_token_expiry_threshhold_seconds`` of its lifetime remain.
    Otherwise a ``client_credentials`` grant (with ``args.oauthScope``) is
    POSTed to ``args.oauthURL``, authenticated with HTTP Basic auth built
    from ``args.clientID``/``args.clientSecret``. The cache
    (``oauth_token`` / ``oauth_token_expiry_seconds``) is only read or
    written while ``oauth_mutex`` is held.

    Returns:
        The access token string, or None when no ``--oauthURL`` was given.

    Raises:
        Exception: the token endpoint answered with an HTTP error; the
            ``urllib.error.HTTPError`` is chained as ``__cause__``.
    """
    global oauth_token
    global oauth_token_expiry_seconds

    if args.oauthURL is None:
        return None

    with oauth_mutex:
        # Read the clock only after acquiring the lock. Another thread may have
        # held it for a whole token round-trip, and an earlier timestamp could
        # make a nearly expired token look valid.
        current_time = time.monotonic()
        if oauth_token and oauth_token_expiry_seconds - oauth_token_expiry_threshhold_seconds > current_time:
            log.debug('OAuth token is still valid')
            return oauth_token

        log.info("Obtaining auth token (Client ID: {}, URL: {})".format(args.clientID, args.oauthURL))

        encoded_credentials = base64.standard_b64encode("{}:{}".format(args.clientID, args.clientSecret).encode()).decode('utf-8')
        headers = {'Authorization': "Basic {}".format(encoded_credentials)}

        data = {
            'grant_type': 'client_credentials',
            'scope': args.oauthScope,
        }

        request = urllib.request.Request(url=args.oauthURL, headers=headers, data=urllib.parse.urlencode(data).encode(), method='POST')
        try:
            with urllib.request.urlopen(request) as response:
                json_response = json.loads(response.read().decode('utf-8'))
        except urllib.error.HTTPError as err:
            raise Exception("Failed to obtain authentication token. Status: {}, Error: {}".format(err.code, err.read().decode())) from err

        oauth_token = json_response["access_token"]
        oauth_token_expiry_seconds = time.monotonic() + json_response["expires_in"]

        log.debug("Token TTL: %d" % json_response["expires_in"])
        return oauth_token


def send_get_voices_request(grpc_client, request, get_all_voices=False):
    """Send a GetVoices request and log the server's answer.

    With ``--sendHTTP`` the request is serialized to JSON and sent with
    method GET to ``<serverUrl>/api/v1/voices/``. A Bearer token is added
    when OAuth is configured. Otherwise the request goes over gRPC: to
    ``GetAllVoices`` when ``get_all_voices`` or ``--getAllVoices`` is set,
    else to ``GetVoices``. The trailing metadata is logged.

    Args:
        grpc_client: stub exposing ``GetVoices``/``GetAllVoices`` (unused
            on the HTTP path).
        request: GetVoicesRequest message.
        get_all_voices: use the GetAllVoices RPC even without
            ``--getAllVoices``.
    """
    log.info("Sending GetVoices request")

    client_span = None
    get_voices_span = None
    metadata = []
    http_headers = {}

    try:
        if args.jaeger:
            log.debug("Injecting Jaeger span context into request")
            client_span = tracer.start_span("Client.gRPC")
            get_voices_span = tracer.start_span(
                "Client.GetVoices", child_of=client_span)
            carrier = dict()
            tracer.inject(get_voices_span.context,
                          opentracing.propagation.Format.TEXT_MAP, carrier)
            metadata.append(('uber-trace-id', carrier['uber-trace-id']))

        # Each optional header goes to both the gRPC metadata and the HTTP
        # headers, as send_synthesis_request does.
        if args.clientRequestID:
            metadata.append(('x-client-request-id', args.clientRequestID))
            http_headers['x-client-request-id'] = args.clientRequestID

        if args.neural:
            log.info("Adding x-nuance-tts-neural header")
            metadata.append(('x-nuance-tts-neural', 'true'))
            http_headers['x-nuance-tts-neural'] = 'true'

        if args.sendHTTP:
            current_oauth_token = get_oauth2_token()
            if current_oauth_token:
                http_headers['Authorization'] = 'Bearer {}'.format(current_oauth_token)

            json_response = send_http_request_with_json_response(
                url=args.serverUrl + "/api/v1/voices/",
                data=json_format.MessageToJson(request).encode(),
                headers=http_headers,
                method="GET")
            log.info(json.dumps(json_response, sort_keys=True,
                                indent=2, separators=(',', ': ')))
        else:
            if get_all_voices or args.getAllVoices:
                get_voices_rpc = grpc_client.GetAllVoices
            else:
                get_voices_rpc = grpc_client.GetVoices

            response, call = get_voices_rpc.with_call(request=request, timeout=args.timeoutSeconds, metadata=metadata)
            log.info(text_format.MessageToString(response))
            for key, value in call.trailing_metadata():
                log.info('Synthesis client received trailing metadata: key=%s value=%s' % (key, value))
    finally:
        # Finish the spans even if the call raised, so they are not left open.
        if get_voices_span:
            get_voices_span.finish()
        if client_span:
            client_span.finish()



def process_synthesis_response(request, response, start, synthesis_span, client_span, return_package, request_info):
    """Log one synthesis response and write its audio to disk as configured.

    Args:
        request: the SynthesisRequest that produced ``response``.
        response: a gRPC synthesis response message or, with
            ``--sendHTTP``, the JSON dict returned by the HTTP API.
        start: ``time.monotonic()`` value taken before the request was
            sent; used to compute the first-chunk latency.
        synthesis_span, client_span: tracing spans (not used here).
        return_package: mutable per-request state, updated in place. Keys:
            ``received_first_audio_chunk``, ``num_chunks``, ``currentaudiolen``,
            ``audio_file``, ``audio_file_name``.
        request_info: per-request constants. Keys: ``sampleRate``,
            ``bitsPerSample``, ``channels``, ``audioformat``, ``extension``,
            ``num_iterations``.

    Returns:
        ``return_package``. It is returned even when a unary/HTTP response
        carries a failure status, so the caller keeps its handle to the open
        audio file and can close it.
    """
    global total_first_chunk_latency

    audio_format = request.audio_params.audio_format

    if args.sendHTTP:
        # The HTTP reply is a JSON dict with base64-encoded audio. Proto3 JSON
        # may omit an empty field, so a missing "audio" means no audio.
        decoded_audio_response = base64.b64decode(response.get("audio", ""))
        response = json_format.Parse(json.dumps(response), UnarySynthesisResponse())
        response.audio = decoded_audio_response

    if args.sendUnary or args.sendHTTP:
        if response.status.code != 200:
            if response.HasField("events"):
                log.info("Received events")
                log.info(text_format.MessageToString(response.events))
            log.error("Received status response: FAILED")
            log.error("Code: {}, Message: {}".format(
                response.status.code, response.status.message))
            log.error('Error: {}'.format(response.status.details))
            return return_package

    if args.sendUnary or args.sendHTTP or response.HasField("audio"):
        log.info("Received audio: %d bytes" % len(response.audio))
        if not return_package["received_first_audio_chunk"]:
            return_package["received_first_audio_chunk"] = True
            latency = time.monotonic() - start
            log.info("First chunk latency: {} seconds".format(latency))
            total_first_chunk_latency = total_first_chunk_latency + latency
            log.info("Average first-chunk latency (over {} synthesis requests): {} seconds".format(
                total_synthesis, total_first_chunk_latency/(total_synthesis)))

        if args.saveAudio:
            audio_file = return_package["audio_file"]
            if args.saveAudioAsWav:
                if audio_format.HasField("ogg_opus") or audio_format.HasField("opus"):
                    log.warning(
                        "Cannot save wave format for Opus, ignoring")
                else:
                    return_package["currentaudiolen"] += len(response.audio)
                    waveheader = generate_wav_header(
                        request_info["sampleRate"], request_info["bitsPerSample"], request_info["channels"], return_package["currentaudiolen"], request_info["audioformat"])
                    if audio_file is not None:
                        # Rewrite the header at offset 0 with the new total
                        # data size, then return to the end to append.
                        audio_file.seek(0, 0)
                        audio_file.write(waveheader)
                        audio_file.seek(0, 2)
            if audio_file is not None:
                audio_file.write(response.audio)

        if args.saveAudioChunks:
            if audio_format.HasField("ogg_opus"):
                log.warning(
                    "Cannot save separate audio chunks for Ogg Opus, ignoring")
            else:
                return_package["num_chunks"] = return_package["num_chunks"] + 1
                chunk_file_name = "%s_i%d_s%d_c%d.%s" % (
                    thread_context.file, request_info["num_iterations"], thread_context.num_synthesis, return_package["num_chunks"], request_info["extension"])
                chunk_file_name = os.path.join(args.audioDir, chunk_file_name)
                with open(chunk_file_name, "wb") as chunk_audio_file:
                    if args.saveAudioAsWav:
                        if audio_format.HasField("opus"):
                            log.warning(
                                "Cannot save audio chunks as wav for Opus, ignoring")
                        else:
                            # Each chunk file is standalone, so its header
                            # describes only this chunk's audio.
                            waveheader = generate_wav_header(
                                request_info["sampleRate"], request_info["bitsPerSample"], request_info["channels"], len(response.audio), request_info["audioformat"])
                            chunk_audio_file.write(waveheader)
                    chunk_audio_file.write(response.audio)
                log.info("Wrote audio chunk to %s" % chunk_file_name)

    if response.HasField("events"):
        log.info("Received events")
        log.info(text_format.MessageToString(response.events))

    if response.HasField("status"):
        if response.status.code == 200:
            log.info("Received status response: SUCCESS")
        else:
            log.error("Received status response: FAILED")
            log.error("Code: {}, Message: {}".format(
                response.status.code, response.status.message))
            log.error('Error: {}'.format(response.status.details))
    return return_package


def _audio_file_settings(audio_format):
    """Map an audio_format message to (extension, sample_rate_hz, bits_per_sample, wav_format_tag).

    WAVE format tags are 1 (PCM), 6 (A-law) and 7 (mu-law). Opus formats get
    zeros because no WAVE header is ever written for them.
    """
    if audio_format.HasField("pcm"):
        return "pcm", audio_format.pcm.sample_rate_hz, 16, 1
    if audio_format.HasField("alaw"):
        return "alaw", 8000, 8, 6
    if audio_format.HasField("ulaw"):
        return "ulaw", 8000, 8, 7
    if audio_format.HasField("ogg_opus"):
        return "ogg", 0, 0, 0
    if audio_format.HasField("opus"):
        return "opus", 0, 0, 0
    # No format selected: assume 22.05 kHz 16-bit PCM.
    return "pcm", 22050, 16, 1


def send_synthesis_request(grpc_client, request, num_iterations, metadata=None):
    """Send one SynthesisRequest and process every response it produces.

    The transport depends on the flags:
    - ``--sendUnary``: the UnarySynthesize RPC.
    - ``--sendHTTP``: a POST to ``<serverUrl>/api/v1/synthesize/``.
    - otherwise: the streaming Synthesize RPC.

    Each response is handed to process_synthesis_response(). With
    ``--saveAudio`` the whole audio goes to
    ``<audioDir>/<flow>_i<iteration>_s<n>.<ext>``. That file and any tracing
    spans are closed even if the request fails.

    Args:
        grpc_client: stub exposing ``Synthesize``/``UnarySynthesize``.
        request: SynthesisRequest message.
        num_iterations: 1-based iteration number used in audio file names.
        metadata: optional extra gRPC metadata as (key, value) pairs, sent
            ahead of the headers added here. The caller's sequence is
            copied, never modified.
    """
    global total_synthesis
    total_synthesis = total_synthesis + 1
    thread_context.num_synthesis = thread_context.num_synthesis + 1

    # Previously this argument was overwritten and silently discarded.
    metadata = list(metadata) if metadata else []
    http_headers = {}
    client_span = None
    synthesis_span = None
    audio_file = None
    audio_file_name = ""
    extension, sample_rate, bits_per_sample, wav_format = "", 0, 0, 0
    audio_format = request.audio_params.audio_format

    if args.saveAudio or args.saveAudioChunks:
        extension, sample_rate, bits_per_sample, wav_format = _audio_file_settings(audio_format)
        if args.saveAudioAsWav:
            if audio_format.HasField("ogg_opus") or audio_format.HasField("opus"):
                log.warning("Cannot set to wav format for Ogg Opus, ignoring")
            else:
                extension = "wav"

    if args.saveAudio:
        if audio_format.HasField("opus"):
            log.warning("Cannot save whole audio for Opus, ignoring")
        else:
            audio_file_name = "%s_i%d_s%d.%s" % (
                thread_context.file, num_iterations, thread_context.num_synthesis, extension)
            audio_file_name = os.path.join(args.audioDir, audio_file_name)
            audio_file = open(audio_file_name, "wb")

    try:
        if args.appid:
            metadata.append(('x-nuance-client-id', args.appid))
            http_headers['x-nuance-client-id'] = args.appid

        if args.neural:
            log.info("Adding x-nuance-tts-neural header")
            metadata.append(('x-nuance-tts-neural', 'true'))
            http_headers['x-nuance-tts-neural'] = 'true'

        if args.clientRequestID:
            metadata.append(('x-client-request-id', args.clientRequestID))
            http_headers['x-client-request-id'] = args.clientRequestID

        if args.jaeger:
            log.debug("Injecting Jaeger span context into request")
            client_span = tracer.start_span("Client.gRPC")
            if args.sendUnary or args.sendHTTP:
                synthesis_span = tracer.start_span(
                    "Client.UnarySynthesize", child_of=client_span)
            else:
                synthesis_span = tracer.start_span(
                    "Client.Synthesize", child_of=client_span)
            carrier = dict()
            tracer.inject(synthesis_span.context,
                          opentracing.propagation.Format.TEXT_MAP, carrier)
            metadata.append(('uber-trace-id', carrier['uber-trace-id']))

        request_info = {"sampleRate": sample_rate, "bitsPerSample": bits_per_sample, "channels": 1, "audioformat": wav_format, "extension": extension, "num_iterations": num_iterations}
        # process_synthesis_response() updates this dict in place; its return
        # value is deliberately not relied upon.
        return_package = {"received_first_audio_chunk": False, "num_chunks": 0, "currentaudiolen": 0, "audio_file": audio_file, "audio_file_name": audio_file_name}

        start = time.monotonic()

        if args.sendUnary:
            log.info("Sending Unary Synthesis request")
            response, call = grpc_client.UnarySynthesize.with_call(request=request, timeout=args.timeoutSeconds, metadata=metadata)
            process_synthesis_response(request, response, start, synthesis_span, client_span, return_package, request_info)
            for key, value in call.trailing_metadata():
                log.info('Synthesis client received trailing metadata: key=%s value=%s' % (key, value))
        elif args.sendHTTP:
            current_oauth_token = get_oauth2_token()
            if current_oauth_token:
                http_headers['Authorization'] = 'Bearer {}'.format(current_oauth_token)

            log.info("Sending HTTP Synthesis request")
            json_response = send_http_request_with_json_response(url=args.serverUrl + "/api/v1/synthesize/", data=json_format.MessageToJson(request).encode(), headers=http_headers, method="POST")
            if json_response:
                process_synthesis_response(request, json_response, start, synthesis_span, client_span, return_package, request_info)
            else:
                log.error("Failed to get response from server!")
        else:
            log.info("Sending Synthesis request")
            responses = grpc_client.Synthesize(request=request, timeout=args.timeoutSeconds, metadata=metadata)
            for response in responses:
                process_synthesis_response(request, response, start, synthesis_span, client_span, return_package, request_info)
            for key, value in responses.trailing_metadata():
                log.info('Synthesis client received trailing metadata: key=%s value=%s' % (key, value))
    finally:
        # Always release the output file and finish the spans, even on errors.
        if audio_file is not None:
            audio_file.close()
        if synthesis_span:
            synthesis_span.finish()
        if client_span:
            client_span.finish()

    if audio_file is not None:
        log.info("Wrote audio to %s" % audio_file_name)


def parse_args(argv=None):
    """Parse command-line options into the module-global ``args``.

    Value options that have a default also use it as ``const``. Passing such
    a flag without a value (e.g. a bare ``-i``) therefore yields the
    documented default instead of None. Later code cannot handle None
    (``range(args.iterations)``, ``args.maxReceiveSizeMB * 1024``,
    ``os.path.exists(args.audioDir)``).

    Args:
        argv: list of argument strings to parse. None (the default) parses
            ``sys.argv[1:]``, as before.
    """
    global args
    parser = argparse.ArgumentParser(
        prog="client.py",
        usage="%(prog)s [-options]",
        add_help=False,
        formatter_class=lambda prog: argparse.HelpFormatter(
            prog, max_help_position=45, width=100)
    )

    options = parser.add_argument_group("options")
    options.add_argument("-h", "--help", action="help",
                         help="Show this help message and exit")
    options.add_argument("--appid", metavar="appID:client-id", nargs="?", help="Client ID or group name, prefixed with appID:")
    options.add_argument("--token", nargs="?", help=argparse.SUPPRESS)
    options.add_argument("-f", "--files", metavar="file", nargs="+",
                         help="List of flow files to execute sequentially, default=['flow.py']", default=['flow.py'])
    options.add_argument("-p", "--parallel", action="store_true",
                         help="Run each flow in a separate thread")
    options.add_argument("-i", "--iterations", metavar="num", nargs="?",
                         help="Number of times to run the list of files, default=1", default=1, const=1, type=int)
    options.add_argument("--infinite", action="store_true",
                         help="Run all files infinitely (overrides number of iterations)")
    options.add_argument("-t", "--timeoutSeconds", metavar="num", nargs="?",
                         help="Timeout in seconds for every RPC call, default=30", default=30, const=30, type=int)
    options.add_argument("-s", "--serverUrl", metavar="url", nargs="?",
                         help="NVC server URL, default=localhost:8080", default='localhost:8080', const='localhost:8080')
    options.add_argument("--oauthURL", metavar="url", nargs="?",
                         help="OAuth 2.0 URL")
    options.add_argument("--clientRequestID", metavar="id", nargs="?",
                         help="Client-generated request ID")
    options.add_argument("--clientID", metavar="id", nargs="?",
                         help="OAuth 2.0 Client ID")
    options.add_argument("--clientSecret", metavar="secret", nargs="?",
                         help="OAuth 2.0 Client Secret")
    options.add_argument("--oauthScope", metavar="scope", nargs="?",
                         help="OAuth 2.0 Scope, default=tts", default='tts', const='tts')
    options.add_argument("--secure", action="store_true",
                         help="Connect to the server using a secure gRPC channel")
    options.add_argument("--rootCerts",  metavar="file", nargs="?",
                         help="Root certificates when using a secure gRPC channel")
    options.add_argument("--privateKey",  metavar="file", nargs="?",
                         help="Certificate private key when using a secure gRPC channel")
    options.add_argument("--certChain",  metavar="file", nargs="?",
                         help="Certificate chain when using a secure gRPC channel")
    options.add_argument("--audioDir", metavar="dir", nargs="?",
                         help="Audio output directory, default=./audio", default='./audio', const='./audio')
    options.add_argument("--saveAudio", action="store_true",
                         help="Save whole audio to disk")
    options.add_argument("--saveAudioChunks", action="store_true",
                         help="Save each individual audio chunk to disk")
    options.add_argument("--saveAudioAsWav", action="store_true",
                         help="Save each audio file in the WAVE format")
    options.add_argument("--jaeger", metavar="addr", nargs="?", const='udp://localhost:6831',
                         help="Send UDP opentrace spans, default addr=udp://localhost:6831")
    options.add_argument("--sendUnary", action="store_true",
                         help="Receive one response (UnarySynthesize) instead of a stream of responses (Synthesize)")
    options.add_argument("--sendHTTP", action="store_true",
                         help="Send the requests using the HTTP-to-gRPC API")
    options.add_argument("--maxReceiveSizeMB", metavar="megabytes", nargs="?",
                         help="Maximum length of gRPC server response in megabytes, default=50 MB", default=50, const=50, type=int)
    options.add_argument("--getAllVoices", action="store_true",
                         help=argparse.SUPPRESS)
    options.add_argument("--neural", action="store_true",
                         help="Send the request to Neural TTS, if available.")
    args = parser.parse_args(argv)


def initialize_tracing():
    """Set up the module-global Jaeger ``tracer`` when ``--jaeger`` was given.

    The tracing libraries are imported only in that case. ``opentracing``
    and ``tracer`` are published as module globals for the request
    functions. The agent address must be a URL with a hostname and port.

    Raises:
        Exception: the ``--jaeger`` address has no netloc, hostname or port.
    """
    global opentracing
    global tracer

    if not args.jaeger:
        return

    print("Enabling Jaeger traces")
    import opentracing
    import jaeger_client
    from urllib.parse import urlparse

    agent = urlparse(args.jaeger)
    # Checked in order, so .port is only read once netloc and hostname exist.
    if not agent.netloc:
        problem = "invalid jaeger agent address: {}"
    elif not agent.hostname:
        problem = "missing hostname in jaeger agent address: {}"
    elif not agent.port:
        problem = "missing port in jaeger agent address: {}"
    else:
        problem = None
    if problem is not None:
        raise Exception(problem.format(args.jaeger))

    jaeger_config = jaeger_client.Config(
        config={
            'sampler': {'type': 'const', 'param': 1},
            'local_agent': {
                'reporting_host': agent.hostname,
                'reporting_port': agent.port,
            },
            'logging': True,
        },
        service_name='NVCClient',
        validate=True,
    )
    tracer = jaeger_config.initialize_tracer()


def _read_file_bytes(path):
    """Return the whole contents of ``path`` as bytes, closing the file afterwards."""
    with open(path, 'rb') as handle:
        return handle.read()


def create_channel():
    """Create the gRPC channel to ``args.serverUrl``.

    Call credentials come from ``--token`` when given, otherwise from the
    OAuth endpoint (if ``--oauthURL`` is set).

    With ``--secure`` they are combined with SSL credentials built from the
    optional ``--rootCerts``/``--privateKey``/``--certChain`` files. Without
    it an insecure channel is returned and the call credentials are not
    attached. Both kinds of channel cap received messages at
    ``--maxReceiveSizeMB`` megabytes.

    Returns:
        The new channel.
    """
    call_credentials = None

    if args.token:
        log.debug('Adding CallCredentials using token parameter')
        call_credentials = grpc.access_token_call_credentials(args.token)
    else:
        current_oauth_token = get_oauth2_token()
        if current_oauth_token:
            log.debug('Adding CallCredentials from OAuth endpoint')
            call_credentials = grpc.access_token_call_credentials(current_oauth_token)

    channel_options = [('grpc.max_receive_message_length', args.maxReceiveSizeMB * 1024 * 1024)]

    if not args.secure:
        log.debug("Creating insecure gRPC channel")
        # NOTE: call_credentials is intentionally not used on this path.
        return grpc.insecure_channel(args.serverUrl, options=channel_options)

    log.debug("Creating secure gRPC channel")
    root_certificates = None
    certificate_chain = None
    private_key = None
    # Read each PEM file with a context manager so its handle is closed.
    if args.rootCerts:
        log.debug("Adding root certs")
        root_certificates = _read_file_bytes(args.rootCerts)
    if args.certChain:
        log.debug("Adding cert chain")
        certificate_chain = _read_file_bytes(args.certChain)
    if args.privateKey:
        log.debug("Adding private key")
        private_key = _read_file_bytes(args.privateKey)

    channel_credentials = grpc.ssl_channel_credentials(
        root_certificates=root_certificates, private_key=private_key, certificate_chain=certificate_chain)
    if call_credentials is not None:
        channel_credentials = grpc.composite_channel_credentials(
            channel_credentials, call_credentials)
    return grpc.secure_channel(args.serverUrl, credentials=channel_credentials, options=channel_options)


def worker_thread(file, num_iterations, list_of_requests):
    """Thread target used with --parallel: run one flow file on this thread."""
    return run_one_file(file, num_iterations, list_of_requests)

def run_one_file(file, num_iterations, list_of_requests):
    """Execute the entries of one flow file's ``list_of_requests`` in order.

    Entries are dispatched by type:
    - GetVoicesRequest: passed to send_get_voices_request.
    - SynthesisRequest: passed to send_synthesis_request.
    - int/float: sleep for that many seconds.
    - anything else: logged as a warning and skipped.

    A channel from create_channel() is used as a context manager for the
    whole file.

    Args:
        file: path of the flow file; its basename prefixes saved audio names.
        num_iterations: 1-based iteration number passed to synthesis.
        list_of_requests: the flow file's list of entries.
    """
    # Per-thread naming state: the synthesis counter restarts for every flow.
    thread_context.num_synthesis = 0
    thread_context.file = os.path.basename(file)

    with create_channel() as channel:
        grpc_client = SynthesizerStub(channel=channel)
        log.info("Running file [%s]" % file)
        log.debug(list_of_requests)

        for request in list_of_requests:
            if isinstance(request, GetVoicesRequest):
                send_get_voices_request(grpc_client, request)
            elif isinstance(request, SynthesisRequest):
                send_synthesis_request(grpc_client, request, num_iterations)
            elif isinstance(request, (int, float)):
                log.info("Waiting for {} seconds".format(request))
                time.sleep(request)
            else:
                # Previously such entries were dropped without any notice.
                log.warning("Skipping unsupported entry of type %s in [%s]",
                            type(request).__name__, file)
        log.info("Done running file [%s]" % file)


def run():
    """Program entry point: parse options, then run the flow files.

    Configures DEBUG logging and validates the OAuth options. It then sets
    up tracing, prefetches an OAuth token and creates the audio directory
    if needed. Next it runs every file in ``args.files`` for
    ``args.iterations`` iterations; ``--infinite`` sets this to 100**100.
    With ``--parallel`` each file of an iteration runs on its own thread and
    the iteration waits for all of them. Finally it logs the average
    first-chunk latency.

    Raises:
        Exception: a flow file does not define ``list_of_requests``.
    """
    parse_args()

    log_level = logging.DEBUG
    global log
    log = logging.getLogger('')
    logging.basicConfig(
        format='%(asctime)s (%(thread)d) %(levelname)-5s %(message)s', level=log_level)

    if args.oauthURL:
        if args.clientID is None:
            log.error("OAuth 2.0 URL was supplied but client ID is missing")
            return
        elif args.clientSecret is None:
            log.error("OAuth 2.0 URL was supplied but client secret is missing")
            return

    initialize_tracing()
    get_oauth2_token()

    if (args.saveAudio or args.saveAudioChunks) and not os.path.exists(args.audioDir):
        log.info("Audio directory: {}".format(args.audioDir))
        # makedirs also creates missing parents (os.mkdir failed on e.g. out/audio).
        os.makedirs(args.audioDir, exist_ok=True)

    if args.infinite:
        log.info("Setting iterations to infinity")
        args.iterations = 100**100

    for i in range(args.iterations):
        num_iterations = i + 1
        log.info("Iteration #{} out of {}".format(num_iterations, args.iterations))
        threads = []
        for file in args.files:
            absolute_path = os.path.abspath(file)
            module_name = os.path.splitext(absolute_path)[0]
            module = SourceFileLoader(module_name, absolute_path).load_module()
            # getattr so a missing variable yields the intended error message
            # rather than an AttributeError.
            list_of_requests = getattr(module, "list_of_requests", None)
            if list_of_requests is None:
                raise Exception(
                    "Error importing [%s]: variable list_of_requests not defined" % file)
            if args.parallel:
                log.info("Running flows in parallel")
                thread = threading.Thread(target=worker_thread, args=[file, num_iterations, list_of_requests])
                threads.append(thread)
                thread.start()
            else:
                run_one_file(file, num_iterations, list_of_requests)
        for thread in threads:
            thread.join()
        log.info("Iteration #{} complete".format(num_iterations))

    if total_synthesis > 0:
        log.info("Average first-chunk latency (over {} synthesis requests): {} seconds".format(
            total_synthesis, total_first_chunk_latency/(total_synthesis)))

    if args.jaeger:
        tracer.close()
        # Need to give time to tracer to flush the spans: https://github.com/jaegertracing/jaeger-client-python/issues/50
        time.sleep(2)
    log.info("Done")


if __name__ == "__main__":
    # Only start the client when executed as a script, not when imported.
    run()
