import json
import logging
import os

import boto3


# Root logger; its level is set at the start of each invocation in main().
logger = logging.getLogger()

# Use AWS_REGION when it is set to a non-empty value, otherwise default to us-east-1.
region_name = os.environ.get("AWS_REGION") or "us-east-1"

# Module-level boto3 clients, created once at import time and shared by the
# helper functions below.
lambda_client = boto3.client("lambda", region_name=region_name)
comprehend_client = boto3.client("comprehend", region_name=region_name)
firehose_client = boto3.client("firehose", region_name=region_name)


def _get_comprehend_input(ticket_title, ticket_description):
    """Build comprehend input and return. You can customize this function with your custom fields."""
    return f"{ticket_title}\n{ticket_description}"


def _get_top_classification(classifications, minimum_score):
    """Get the classification with the top score. This function is used to identify top scored resource/operation."""
    sorted_classifications = sorted(classifications, key=lambda k: k["Score"], reverse=True)
    if len(sorted_classifications) and sorted_classifications[0].get("Score") >= minimum_score:
        return sorted_classifications[0]
    else:
        return {"Name": None, "Score": None}


def _classify_ticket(comprehend_input):
    """Classify ticket text with the operation and resource Comprehend endpoints.

    Reads the score thresholds (OPERATION_SCORE_THRESHOLD, RESOURCE_SCORE_THRESHOLD)
    and endpoint ARNs (COMPREHEND_TICKET_OPERATION_CLASSIFIER_ARN,
    COMPREHEND_TICKET_RESOURCE_CLASSIFIER_ARN) from the environment.

    Args:
        comprehend_input: Text to classify.

    Returns:
        A dict with the raw "Classes" lists from each endpoint ("OperationClasses",
        "ResourceClasses") and the name/score of the top class for each
        ("TopOperationType", "TopOperationScore", "TopResourceType",
        "TopResourceScore"). The top name and score are None when no class
        reaches its threshold.

    Raises:
        KeyError: If a required environment variable is missing.
        ValueError: If a threshold is not a valid float.
    """
    # Minimum scores a top class must reach to be reported.
    operation_score_threshold = float(os.environ["OPERATION_SCORE_THRESHOLD"])
    resource_score_threshold = float(os.environ["RESOURCE_SCORE_THRESHOLD"])

    # Comprehend classifier endpoint ARNs.
    operation_classifier_arn = os.environ["COMPREHEND_TICKET_OPERATION_CLASSIFIER_ARN"]
    resource_classifier_arn = os.environ["COMPREHEND_TICKET_RESOURCE_CLASSIFIER_ARN"]

    # Classify the same text against both endpoints.
    operations = comprehend_client.classify_document(
        EndpointArn=operation_classifier_arn, Text=comprehend_input
    ).get("Classes")
    resources = comprehend_client.classify_document(
        EndpointArn=resource_classifier_arn, Text=comprehend_input
    ).get("Classes")
    logger.debug(f"Operations: {json.dumps(operations)}")
    logger.debug(f"Resources: {json.dumps(resources)}")

    # Pick the top-scored operation and resource that meet their thresholds.
    top_operation = _get_top_classification(operations, operation_score_threshold)
    top_resource = _get_top_classification(resources, resource_score_threshold)
    logger.info(f"Top Operation: {top_operation}")
    logger.info(f"Top Resource: {top_resource}")

    return {
        "OperationClasses": operations,
        "ResourceClasses": resources,
        "TopOperationType": top_operation.get("Name"),
        "TopOperationScore": top_operation.get("Score"),
        "TopResourceType": top_resource.get("Name"),
        "TopResourceScore": top_resource.get("Score"),
    }


def _call_ticket_handler(ticket_id, ticket_title, ticket_description, ticket_creation_time, top_operation_type, top_resource_type):
    """Invoke the ticket handler Lambda function with the ticket and its classification.

    The function ARN is read from TICKETHANDLER_FUNCTION_ARN. The invoke uses the
    default invocation type, so the call waits for the handler's result.

    An error raised inside the ticket handler does not make ``invoke`` raise. Lambda
    reports it through the ``FunctionError`` field of the response instead. That
    field is checked here and the handler's error payload is logged at error level,
    so failures are no longer silently discarded.

    Raises:
        KeyError: If TICKETHANDLER_FUNCTION_ARN is not set.
    """
    ticket_handler_arn = os.environ["TICKETHANDLER_FUNCTION_ARN"]
    payload = {
        "TicketId": ticket_id,
        "TicketTitle": ticket_title,
        "TicketDescription": ticket_description,
        "TicketCreationTime": ticket_creation_time,
        "TopOperationType": top_operation_type,
        "TopResourceType": top_resource_type,
    }
    ticket_handler_response = lambda_client.invoke(
        FunctionName=ticket_handler_arn,
        Payload=json.dumps(payload),
    )
    logger.debug(f"Ticket handler response status: {ticket_handler_response.get('StatusCode')}")

    function_error = ticket_handler_response.get("FunctionError")
    if function_error:
        # The handler's error details are returned in the response payload stream.
        error_payload = ticket_handler_response["Payload"].read()
        if isinstance(error_payload, bytes):
            error_payload = error_payload.decode("utf-8", errors="replace")
        logger.error(
            f"Ticket handler {ticket_handler_arn} failed for ticket {ticket_id} "
            f"({function_error}): {error_payload}"
        )


def _push_classification(ticket_id, ticket_title, ticket_description, ticket_creation_time, top_operation_type, top_resource_type):
    """Send one classified ticket as a JSON record to Kinesis Data Firehose.

    The delivery stream name is read from KINESIS_FIREHOSE_DELIVERY_STREAM_NAME.
    The record body is a JSON object with the keys id, title, description,
    creation_time, operation and resource. The put_record response is logged at
    debug level.

    Raises:
        KeyError: If KINESIS_FIREHOSE_DELIVERY_STREAM_NAME is not set.
    """
    record_body = {
        "id": ticket_id,
        "title": ticket_title,
        "description": ticket_description,
        "creation_time": ticket_creation_time,
        "operation": top_operation_type,
        "resource": top_resource_type,
    }
    response = firehose_client.put_record(
        DeliveryStreamName=os.environ["KINESIS_FIREHOSE_DELIVERY_STREAM_NAME"],
        Record={"Data": json.dumps(record_body)},
    )
    logger.debug(f"Firehose response: {response}")


def main(event, context):
    """Lambda entry point: classify a ticket, store the result, and forward it.

    Steps:
        1. Validate that the event carries TicketId, TicketTitle,
           TicketDescription and TicketCreationTime (all must be truthy).
        2. Classify the title and description with the Comprehend endpoints.
        3. Push the classification to Kinesis Data Firehose.
        4. Invoke the ticket handler Lambda function with the classification.

    Args:
        event: Dict carrying the ticket fields listed above.
        context: Lambda context object (unused).

    Raises:
        ValueError: If a required ticket field is missing or empty. ValueError
            subclasses Exception, so existing ``except Exception`` callers still work.
    """
    # Log at DEBUG; note this includes the ticket title, description and Comprehend input.
    logger.setLevel(logging.DEBUG)

    # Validate required fields in a fixed order so the first missing one is reported.
    required_fields = (
        ("TicketId", "Ticket ID"),
        ("TicketTitle", "Ticket title"),
        ("TicketDescription", "Ticket description"),
        ("TicketCreationTime", "Ticket creation time"),
    )
    for key, label in required_fields:
        if not event.get(key):
            raise ValueError(f"{label} ({key}) is not provided.")

    ticket_id = event["TicketId"]
    ticket_title = event["TicketTitle"]
    ticket_description = event["TicketDescription"]
    ticket_creation_time = event["TicketCreationTime"]
    logger.info(f"Ticket ID: {ticket_id}")
    logger.debug(f"Ticket Title: {ticket_title}")
    logger.debug(f"Ticket Description: {ticket_description}")
    logger.debug(f"Ticket Creation Time: {ticket_creation_time}")

    # Build the Comprehend input text.
    comprehend_input = _get_comprehend_input(ticket_title, ticket_description)
    logger.debug(f"Comprehend input: {comprehend_input}")

    # Classify the ticket.
    classification_result = _classify_ticket(comprehend_input)
    top_operation_type = classification_result.get("TopOperationType")
    top_resource_type = classification_result.get("TopResourceType")
    logger.info(f"Classifier result: {json.dumps(classification_result, indent=4)}")

    # Push the classification to Kinesis Data Firehose; the code comment in the
    # original pipeline states this feeds a Redshift cluster.
    _push_classification(ticket_id, ticket_title, ticket_description, ticket_creation_time, top_operation_type, top_resource_type)

    # Call the ticket handler Lambda function with the classification.
    _call_ticket_handler(ticket_id, ticket_title, ticket_description, ticket_creation_time, top_operation_type, top_resource_type)
