
from pyspark.sql import SparkSession
import sys

# Require exactly one positional CLI argument: the target S3 bucket name.
args = sys.argv[1:]
if len(args) != 1:
    print("Usage: pass s3_bucket_name")
    sys.exit(1)
s3_bucket_name = args[0]

# Create (or reuse) the Spark session for this job.
builder = SparkSession.builder.appName("ny_taxi_summary")
spark = builder.getOrCreate()

# Load the green-zone parquet output from S3 and expose it to Spark SQL
# as the temp view "src_green_zone".
s3_green_zone_path = f"s3://{s3_bucket_name}/output_data/green_zone/*"
df_green_zone = spark.read.format("parquet").load(s3_green_zone_path)
df_green_zone.createOrReplaceTempView("src_green_zone")


# Load the yellow-zone parquet output from S3 and expose it to Spark SQL
# as the temp view "src_yellow_zone".
s3_yellow_zone_path = f"s3://{s3_bucket_name}/output_data/yellow_zone/*"
df_yellow_zone = spark.read.format("parquet").load(s3_yellow_zone_path)
df_yellow_zone.createOrReplaceTempView("src_yellow_zone")


# Build ny_taxi_summary: full outer join of yellow- and green-zone trips on
# (pulocationid, dolocationid), with per-zone-pair aggregates.
# Fix: removed the trailing ";" from the statement — spark.sql() parses a
# single statement and a trailing semicolon raises a ParseException on many
# Spark versions. The literal is split across lines purely for readability;
# the SQL text is otherwise unchanged.
# NOTE(review): every sum(...) measure reads only yellow columns, so any
# green-only zone pair produced by the full outer join gets NULL totals —
# confirm that is the intended semantics.
df_ny_taxi_summary = spark.sql(
    "select  coalesce(green.pu_service_zone, yellow.pu_service_zone) as  pu_service_zone,"
    " coalesce(green.pulocationid, yellow.pulocationid ) as pulocationid,"
    " coalesce(green.do_service_zone, yellow.do_service_zone) as  do_service_zone,"
    " coalesce(green.dolocationid, yellow.dolocationid) as dolocationid ,"
    " cast(sum(yellow.passenger_count) as decimal(10,0)) passenger_count,"
    " sum(yellow.trip_distance) trip_distance,"
    "  sum(yellow.fare_amount) fare_amount ,"
    " sum(yellow.extra) extra ,"
    " sum(yellow.mta_tax) mta_tax,"
    " sum(yellow.tip_amount) tip_amount,"
    " sum(yellow.tolls_amount) tolls_amount,"
    " sum(yellow.improvement_surcharge) improvement_surcharge,"
    " sum(yellow.total_amount) total_amount,"
    " sum(yellow.congestion_surcharge) congestion_surcharge,"
    " sum(yellow.airport_fee) as airport_fee"
    " from src_yellow_zone yellow"
    " full outer join src_green_zone green"
    " on yellow.pulocationid = green.pulocationid"
    " and yellow.dolocationid = green.dolocationid"
    " group by 1,2,3,4"
)
# Preview the first rows of the summary in the driver log.
df_ny_taxi_summary.show(10)

# Write ny_taxi_summary back to S3 as parquet.
# NOTE(review): no write mode is set, so the default (error-if-exists)
# applies and a rerun against an existing output path will fail — confirm
# whether mode("overwrite") is wanted before changing it.
s3_ny_taxi_summary_path = "s3://{}/output_data/ny_taxi_summary/".format(s3_bucket_name)
df_ny_taxi_summary.write.parquet(s3_ny_taxi_summary_path)

# Fix: release cluster resources explicitly at the end of the job; the
# session was previously never stopped.
spark.stop()

