#!/usr/bin/env python3

import argparse
import sys
import time
import logging
import grpc
import os
import requests as req
import base64
import json
from importlib.machinery import SourceFileLoader
import threading
from google.protobuf import text_format
from google.protobuf import json_format
from util import configure_logging

from nuance.asr.forgetme.v1 import forgetme_pb2
from nuance.asr.forgetme.v1 import forgetme_pb2_grpc

# Per-thread client state. Attributes assigned on a threading.local at import
# time exist only on the main thread; run_one_file() re-initializes them at the
# start of every flow so worker threads get their own copies.
thread_context = threading.local()
thread_context.oauth_token = None  # cached OAuth access token, or None
thread_context.num_request = 0  # number of requests sent by this thread
thread_context.oauth_token_expiry_seconds = 0  # time.monotonic() deadline of oauth_token
# Process-wide latency statistics summed per request and averaged by run().
# NOTE(review): incremented from --parallel worker threads without a lock.
total_first_chunk_latency = 0  # sum of per-request latencies, in seconds
total_request = 0  # number of DeleteSpeakerProfiles requests sent

# Parsed command-line namespace; assigned by parse_args().
args = None

def get_oauth2_token():
    """Fetch an OAuth 2.0 client-credentials token for the current thread.

    Does nothing when ``--oauthURL`` was not given, or when the token cached in
    ``thread_context`` does not expire within the next 3 seconds.  Otherwise
    POSTs a ``client_credentials`` grant to ``--oauthURL`` and stores the
    returned ``access_token`` and its ``time.monotonic()`` expiry deadline in
    ``thread_context``.

    With ``--disableBasicAuth`` the client ID and secret are sent as form
    fields; otherwise they are sent as HTTP basic-auth credentials.

    Raises:
        SystemExit: with status 1 if the request fails, times out, or the
            endpoint answers with a status other than 200.
    """
    if args.oauthURL is None:
        return

    # Reuse the cached token unless it is about to expire (3 s safety margin).
    if thread_context.oauth_token and thread_context.oauth_token_expiry_seconds - 3 > time.monotonic():
        log.debug('OAuth token is still valid')
        return

    request = {
        'grant_type': 'client_credentials',
        'scope': args.oauthScope
    }
    auth = None
    if args.disableBasicAuth:
        # Non-Mix credentials: send the client credentials in the form body.
        request['client_id'] = args.clientID
        request['client_secret'] = args.clientSecret
        log.info("Obtaining auth token (Client ID: {}, URL: {})".format(
            args.clientID, args.oauthURL))
        failure_message = "Failed to obtain authentication token, error {}"
    else:
        # Mix-generated credentials: send them as HTTP basic auth.
        auth = (args.clientID, args.clientSecret)
        log.info("Obtaining auth token using basicAuth(Client ID: {}, URL: {})".format(
            args.clientID, args.oauthURL))
        failure_message = "Failed to obtain authentication token from basicauth, error {}"

    try:
        # Bound the wait (seconds) so an unresponsive endpoint cannot hang the client.
        response = req.post(args.oauthURL, auth=auth, data=request, timeout=30)
        if response.status_code != 200:
            raise Exception("Status: {}".format(response.status_code))
    except Exception as er:
        log.error(failure_message.format(er))
        # A bare ``SystemExit`` exits with status 0; report the failure instead.
        raise SystemExit(1) from er

    json_response = response.json()
    thread_context.oauth_token = json_response["access_token"]
    thread_context.oauth_token_expiry_seconds = time.monotonic() + \
        json_response["expires_in"]
    log.debug("Token TTL: %d", json_response["expires_in"])

def _read_metadata_file(path):
    """Return the ``(header, value)`` pairs listed one per line in *path*.

    Each non-blank line must look like ``header: value``; only the first ``:``
    separates the two and surrounding whitespace is stripped.  Blank lines
    (e.g. a trailing newline at end of file) are skipped.

    Raises:
        ValueError: if a non-blank line contains no ``:``.
    """
    pairs = []
    with open(path, 'r') as meta_file:
        for line_number, line in enumerate(meta_file, start=1):
            if not line.strip():
                continue
            header, separator, value = line.partition(':')
            if not separator:
                raise ValueError("{}:{}: expected 'header: value', got {!r}".format(
                    path, line_number, line.rstrip('\n')))
            pairs.append((header.strip(), value.strip()))
    return pairs


def send_deletespeakerprofiles_request(grpc_client, request, metadata=None):
    """Send one DeleteSpeakerProfiles request and record its latency.

    Args:
        grpc_client: ForgetMe gRPC stub exposing ``DeleteSpeakerProfiles``.
        request: the request message; its ``user_id`` is overwritten when
            ``--userId`` is given.
        metadata: optional iterable of ``(header, value)`` pairs sent in
            addition to those derived from ``--meta``, ``--nmaid`` and
            ``--jaeger``.  The caller's object is not modified.

    Raises:
        SystemExit: with status 1 if the RPC fails.
    """
    global total_request, total_first_chunk_latency

    log.info("Sending DeleteSpeakerProfiles request")

    # NOTE: shared with --parallel threads without a lock.
    total_request = total_request + 1
    thread_context.num_request = thread_context.num_request + 1

    # Copy so caller-supplied metadata is honoured but never mutated.
    metadata = list(metadata) if metadata else []

    if args.userId:
        log.info("Override the request user_id with argument userId [%s]" % args.userId)
        request.user_id = args.userId

    if args.meta:
        # ``--meta`` without a value yields True, meaning "use .metadata".
        metadata.extend(_read_metadata_file(args.meta if isinstance(args.meta, str) else '.metadata'))

    if args.nmaid:
        metadata.append(('x-nuance-client-id', args.nmaid))

    client_span = None
    request_span = None
    if args.jaeger:
        log.debug("Injecting Jaeger span context into request")
        client_span = tracer.start_span("Client.gRPC")
        request_span = tracer.start_span(
            "Client.Training", child_of=client_span)
        carrier = dict()
        tracer.inject(request_span.context,
                      opentracing.propagation.Format.TEXT_MAP, carrier)
        metadata.append(('uber-trace-id', carrier['uber-trace-id']))

    log.info("Sending request: {}".format(request))
    log.info("Sending metadata: {}".format(metadata))
    # Time only the RPC itself, not the surrounding log formatting.
    start = time.monotonic()
    try:
        response = grpc_client.DeleteSpeakerProfiles(request=request, metadata=metadata)
    except Exception as er:
        log.error("Error while sending DeleteSpeakerProfiles request, error {}".format(er))
        # A bare ``SystemExit`` exits with status 0; report the failure instead.
        raise SystemExit(1) from er
    else:
        latency = time.monotonic() - start
    finally:
        # Close the tracing spans on both the success and the failure path.
        if request_span:
            request_span.finish()
        if client_span:
            client_span.finish()

    log.info("Received response: {}".format(response))
    total_first_chunk_latency = total_first_chunk_latency + latency
    log.info("First chunk latency: {} seconds".format(latency))
		
def parse_args():
    """Parse the command line into the module-level ``args`` namespace.

    Options that accept an optional value (``nargs="?"``) and have a default
    use that default as ``const`` too, so a flag given without a value (for
    example a bare ``-i``) falls back to the default instead of ``None``.
    """
    global args
    parser = argparse.ArgumentParser(
        prog="forgetme-client.py",
        usage="%(prog)s [-options]",
        add_help=False,
        formatter_class=lambda prog: argparse.HelpFormatter(
            prog, max_help_position=45, width=100)
    )
    options = parser.add_argument_group("options")
    options.add_argument("-h", "--help", action="help",
                         help="Show this help message and exit")
    # Hidden options (not listed in --help).
    options.add_argument("--nmaid", nargs="?", help=argparse.SUPPRESS)
    options.add_argument("--token", nargs="?", help=argparse.SUPPRESS)
    options.add_argument("-f", "--files", metavar="file", nargs="+",
                         help="List of flow files to execute sequentially, default=['forgetme-flow.py']", default=['forgetme-flow.py'])
    options.add_argument('-l', '--loglevel', metavar='lvl', choices=['fatal', 'error', 'warn', 'info', 'debug'], default='info',
                         help='Log level: fatal, error, warn, default=info, debug')
    options.add_argument('-L', '--logfile', metavar='fn', nargs='?', const=True,
                         help='log to file, default fn=fmsgcli-{datetimestamp}.log')
    options.add_argument('-q', '--quiet', action='store_true',
                         help='Disable console logging')
    options.add_argument("-p", "--parallel", action="store_true",
                         help="Run each flow in a separate thread")
    options.add_argument("-i", "--iterations", metavar="num", nargs="?", const=1,
                         help="Number of times to run the list of files, default=1", default=1, type=int)
    options.add_argument("-s", "--serverUrl", metavar="url", nargs="?", const='localhost:9080',
                         help="ForgetMe service URL, default=localhost:9080", default='localhost:9080')
    options.add_argument("-b", "--disableBasicAuth",
                         help="Basic auth is required for Mix-generated credentials, disable for others", action='store_true')
    options.add_argument("--oauthURL", metavar="url", nargs="?",
                         help="OAuth 2.0 URL")
    options.add_argument("--clientID", metavar="id", nargs="?",
                         help="OAuth 2.0 Client ID")
    options.add_argument("--clientSecret", metavar="secret", nargs="?",
                         help="OAuth 2.0 Client Secret")
    options.add_argument("--oauthScope", metavar="scope", nargs="?", const='asr.forgetme',
                         help="OAuth 2.0 Scope, default=asr.forgetme", default='asr.forgetme')
    options.add_argument("--secure", action="store_true",
                         help="Connect to the server using a secure gRPC channel")
    options.add_argument("--rootCerts", metavar="file", nargs="?",
                         help="Root certificates when using a secure gRPC channel")
    options.add_argument("--privateKey", metavar="file", nargs="?",
                         help="Certificate private key when using a secure gRPC channel")
    options.add_argument("--certChain", metavar="file", nargs="?",
                         help="Certificate chain when using a secure gRPC channel")
    options.add_argument("--jaeger", metavar="addr", nargs="?", const='udp://localhost:6831',
                         help="Send UDP opentrace spans, default addr=udp://localhost:6831")
    # ``--meta`` without a value yields True, which callers map to ".metadata".
    options.add_argument('--meta', metavar='txtfile', nargs='?', const=True,
                         help='Read header:value metadata lines from file, default=.metadata', default=None)
    options.add_argument("--maxReceiveSizeMB", metavar="megabytes", nargs="?", const=50,
                         help="Maximum length of gRPC server response in megabytes, default=50 MB", default=50, type=int)
    options.add_argument("--userId", metavar="id", nargs="?",
                         help="User ID of speaker profile, if provided overrides the request.user_id")
    args = parser.parse_args()

def initialize_tracing():
    """Set up the global Jaeger ``tracer`` if ``--jaeger`` was supplied.

    The agent address is parsed as a URL and must carry both a hostname and a
    port (for example ``udp://localhost:6831``).  The module-level
    ``opentracing`` name is bound here as well, for later span injection.

    Raises:
        Exception: when the agent address lacks a network location, a
            hostname or a port (checked in that order).
    """
    global opentracing
    global tracer

    if not args.jaeger:
        return

    print("Enabling Jaeger traces")
    import opentracing
    import jaeger_client
    from urllib.parse import urlparse

    jaeger_addr = args.jaeger
    parsed = urlparse(jaeger_addr)
    if not parsed.netloc:
        raise Exception("invalid jaeger agent address: {}".format(jaeger_addr))
    if not parsed.hostname:
        raise Exception("missing hostname in jaeger agent address: {}".format(jaeger_addr))
    if not parsed.port:
        raise Exception("missing port in jaeger agent address: {}".format(jaeger_addr))

    # Sample every span ('const' sampler with param 1) and report to the agent.
    sampler_settings = {'type': 'const', 'param': 1}
    agent_settings = {
        'reporting_host': parsed.hostname,
        'reporting_port': parsed.port,
    }
    tracer = jaeger_client.Config(
        config={
            'sampler': sampler_settings,
            'local_agent': agent_settings,
            'logging': True,
        },
        service_name='fmsgcli-client',
        validate=True,
    ).initialize_tracer()

def create_channel():
    """Create the gRPC channel to ``--serverUrl``.

    Call credentials come from ``--token`` when given, otherwise from
    ``get_oauth2_token()`` if it produced a token.  With ``--secure`` an SSL
    channel is built from the optional ``--rootCerts``, ``--certChain`` and
    ``--privateKey`` files and combined with those call credentials.  Without
    ``--secure`` an insecure channel is returned and no call credentials are
    attached; a warning is logged if a token would have been used.  Both
    channel kinds cap responses at ``--maxReceiveSizeMB`` megabytes.

    Returns:
        The new gRPC channel; the caller is responsible for closing it.
    """
    call_credentials = None
    if args.token:
        log.debug('Adding CallCredentials using token parameter')
        call_credentials = grpc.access_token_call_credentials(args.token)
    else:
        get_oauth2_token()
        if thread_context.oauth_token:
            log.debug('Adding CallCredentials from OAuth endpoint')
            call_credentials = grpc.access_token_call_credentials(
                thread_context.oauth_token)

    options = [('grpc.max_receive_message_length', args.maxReceiveSizeMB * 1024 * 1024)]

    if not args.secure:
        log.debug("Creating insecure gRPC channel")
        if call_credentials is not None:
            # An insecure channel cannot carry call credentials, so the token is dropped.
            log.warning("Access token is not sent over an insecure channel; use --secure to send it")
        return grpc.insecure_channel(args.serverUrl, options=options)

    log.debug("Creating secure gRPC channel")

    def read_file(path):
        # Close the file promptly instead of leaving it to the garbage collector.
        with open(path, 'rb') as f:
            return f.read()

    root_certificates = None
    certificate_chain = None
    private_key = None
    if args.rootCerts:
        log.debug("Adding root certs")
        root_certificates = read_file(args.rootCerts)
    if args.certChain:
        log.debug("Adding cert chain")
        certificate_chain = read_file(args.certChain)
    if args.privateKey:
        log.debug("Adding private key")
        private_key = read_file(args.privateKey)

    channel_credentials = grpc.ssl_channel_credentials(
        root_certificates=root_certificates, private_key=private_key, certificate_chain=certificate_chain)
    if call_credentials is not None:
        channel_credentials = grpc.composite_channel_credentials(
            channel_credentials, call_credentials)
    return grpc.secure_channel(args.serverUrl, credentials=channel_credentials, options=options)
        
def run_one_file(file, list_of_requests, module=None):
    """Execute one flow's requests over a fresh channel in the current thread.

    Args:
        file: path of the flow file; used for logging and stored (basename
            only) in ``thread_context.file``.
        list_of_requests: sequence whose ``DeleteSpeakerProfilesRequest``
            items are sent and whose int/float items are sleeps in seconds.
            Any other item is logged and skipped.
        module: the loaded flow module.  It is not used here and is optional,
            so callers may pass only ``file`` and ``list_of_requests``.
    """
    # Reset per-thread state: threading.local attributes set at import time
    # exist only on the main thread, and each flow starts without a token.
    thread_context.oauth_token = None
    thread_context.num_request = 0
    thread_context.oauth_token_expiry_seconds = 0

    with create_channel() as channel:
        grpc_client = forgetme_pb2_grpc.ForgetMeStub(channel=channel)
        log.info("Running file [%s]" % file)

        thread_context.file = os.path.basename(file)

        for request in list_of_requests:
            if isinstance(request, forgetme_pb2.DeleteSpeakerProfilesRequest):
                send_deletespeakerprofiles_request(grpc_client, request)
            elif isinstance(request, (int, float)):
                log.info("Waiting for {} seconds".format(request))
                time.sleep(request)
            else:
                log.warning("Skipping unsupported entry in [%s]: %r", file, request)
        log.info("Done running file [%s]" % file)

def _load_flow_module(file):
    """Execute the flow file *file* and return it as a fresh module object.

    The module is named after the file's absolute path without its extension
    and registered in ``sys.modules`` under that name, mirroring the
    deprecated ``SourceFileLoader.load_module()``.  The loader is passed
    explicitly so files without a ``.py`` suffix still load.
    """
    import importlib.util

    absolute_path = os.path.abspath(file)
    module_name = os.path.splitext(absolute_path)[0]
    loader = SourceFileLoader(module_name, absolute_path)
    spec = importlib.util.spec_from_file_location(module_name, absolute_path, loader=loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    return module


def run():
    """Parse arguments, then execute every flow file ``--iterations`` times.

    Each flow file must define ``list_of_requests``.  With ``--parallel`` all
    files of an iteration run concurrently in their own threads and are
    joined before the next iteration; otherwise they run one after another.
    Finally the average per-request latency is logged and the Jaeger tracer,
    if any, is flushed.

    Raises:
        Exception: if a flow file does not define ``list_of_requests`` (or
            sets it to ``None``).
    """
    global log
    global num_iterations

    parse_args()

    log = logging.getLogger('fmsgcli')
    configure_logging(log, args)

    if args.oauthURL:
        if args.clientID is None:
            log.error("OAuth 2.0 URL was supplied but client ID is missing")
            return
        elif args.clientSecret is None:
            log.error("OAuth 2.0 URL was supplied but client secret is missing")
            return

    initialize_tracing()

    for i in range(args.iterations):
        num_iterations = i + 1
        log.info("Iteration #{}".format(num_iterations))
        threads = []
        for file in args.files:
            module = _load_flow_module(file)
            # getattr: a missing variable must produce the message below, not AttributeError.
            list_of_requests = getattr(module, 'list_of_requests', None)
            if list_of_requests is None:
                raise Exception(
                    "Error importing [%s]: variable list_of_requests not defined" % file)
            if args.parallel:
                log.info("Running flows in parallel")
                thread = threading.Thread(target=run_one_file, args=[
                                          file, list_of_requests, module])
                threads.append(thread)
                thread.start()
            else:
                log.info("Running flows in serial")
                run_one_file(file, list_of_requests, module)
        for thread in threads:
            thread.join()
        log.info("Iteration #{} complete".format(num_iterations))

    if total_request > 0:
        log.info("Average first-chunk latency (over {} requests): {} seconds".format(
            total_request, total_first_chunk_latency/(total_request)))

    if args.jaeger:
        tracer.close()
        # Need to give time to tracer to flush the spans: https://github.com/jaegertracing/jaeger-client-python/issues/50
        time.sleep(2)
    print("Done")


if __name__ == "__main__":
    # Only start the client when executed as a script, not when imported.
    run()
