import marimo

# Notebook metadata written by marimo, and the App object that every
# @app.cell function below registers with.
__generated_with = "0.17.0"
app = marimo.App(width="medium")


@app.cell
def _():
    # Setup cell: import marimo and open the Snowflake connection used by the
    # other cells (marimo hands `conn`, `cursor` and `mo` to them by name).
    import marimo as mo
    from komodo import get_snowflake_connection

    # get_snowflake_connection is a project helper; where it reads credentials
    # from is not visible here.
    conn = get_snowflake_connection()
    cursor = conn.cursor()
    # USE DATABASE sets the session's current database, so the unqualified
    # SCHEMA.TABLE / INFORMATION_SCHEMA references in later cells resolve
    # against DATA (assumes mo.sql(engine=conn) shares this session — confirm).
    cursor.execute("USE DATABASE DATA;")
    return conn, cursor, mo


# SQL cell: list the schemas in the current database (set to DATA by the
# setup cell). The result is shown as the cell's output; `_df` is cell-local.
@app.cell
def _(conn, mo):
    _df = mo.sql(
        f"""
        SHOW SCHEMAS;
        """,
        engine=conn
    )
    return


# SQL cell: list the tables in the PROD schema, sorted by name so the listing
# is stable and easy to scan when choosing the SCHEMA/TABLE values for the
# query cell. The result is shown as the cell's output; `_df` is cell-local.
@app.cell
def _(conn, mo):
    _df = mo.sql(
        f"""
        SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'PROD' ORDER BY TABLE_NAME;
        """,
        engine=conn
    )
    return


@app.cell
def _(cursor):
    # Query cell: pull rows of the chosen table into a polars DataFrame
    # (`rx_pl`) that the dropdown and chart cells read.
    import polars as pl

    # Fill these in before running, e.g. SCHEMA = "PROD",
    # TABLE = "T1_ALL_TX_CLAIMS_RX".
    SCHEMA = "PUT_YOUR_SCHEMA_HERE" # PROD
    TABLE = "PUT_YOUR_TABLE_HERE" # T1_ALL_TX_CLAIMS_RX
    # Upper bound on the number of rows loaded into memory.
    ROW_LIMIT = 100_000

    # Fail fast with an actionable message while the placeholders are still in
    # place, instead of sending a query for a table that does not exist.
    if "PUT_YOUR_" in SCHEMA or "PUT_YOUR_" in TABLE:
        raise ValueError(
            "Set SCHEMA and TABLE in this cell before running the query."
        )

    # NOTE(review): LIMIT without ORDER BY returns an arbitrary subset of rows,
    # so downstream aggregates describe a sample, not the whole table.
    rx_query = f"SELECT * FROM {SCHEMA}.{TABLE} LIMIT {ROW_LIMIT};"
    cursor.execute(rx_query)
    # fetch_pandas_all() yields a pandas DataFrame; convert it to polars.
    rx_pl = pl.from_pandas(cursor.fetch_pandas_all())
    return pl, rx_pl


@app.cell
def _(mo, rx_pl):
    # Dropdown cell: choose which BRAND_NAME the chart cell filters on.
    # "All" is a sentinel meaning "no brand filter".
    #
    # Null brand names are dropped. Dropdown options are expected to be
    # strings, and a None option could never match the chart cell's
    # `BRAND_NAME == value` filter, so selecting it would only yield an
    # empty chart.
    brand_names = ["All"] + rx_pl["BRAND_NAME"].drop_nulls().unique().sort().to_list()

    # Create the dropdown UI element
    brand_selector = mo.ui.dropdown(
        options=brand_names,
        value="All",
        label="Select a Brand Name:"
    )

    # Display the dropdown
    brand_selector
    return (brand_selector,)


@app.cell
def _(brand_selector, mo, pl, rx_pl):
    # Chart cell: bar chart of total DAYS_SUPPLY per month over the two years
    # ending at the latest FILL_DATE, for the brand picked in `brand_selector`.
    import altair as alt

    # Filter the dataframe based on the dropdown selection ("All" = no filter).
    if brand_selector.value == "All":
        filtered_df = rx_pl
        title_brand = "All Brands"
    else:
        filtered_df = rx_pl.filter(pl.col("BRAND_NAME") == brand_selector.value)
        title_brand = brand_selector.value

    # Most recent fill date in the filtered data; None when there are no rows
    # or every FILL_DATE is null.
    max_date = filtered_df.select(pl.max("FILL_DATE")).item()

    if max_date is not None:
        # Same calendar day two years earlier. When max_date is Feb 29 the
        # target year is never a leap year, so replace() raises ValueError;
        # clamp to Feb 28 in that case instead of crashing the cell.
        try:
            start_date = max_date.replace(year=max_date.year - 2)
        except ValueError:
            start_date = max_date.replace(year=max_date.year - 2, day=28)
        time_filtered_df = filtered_df.filter(pl.col("FILL_DATE") >= start_date)
    else:
        # Empty selection or no dates: nothing to window, chart the frame as-is.
        time_filtered_df = filtered_df

    # Extract month from FILL_DATE and calculate the sum of DAYS_SUPPLY
    monthly_supply = (
        time_filtered_df
        .with_columns(
            # 'YYYY-MM' strings sort lexically in chronological order
            pl.col("FILL_DATE").dt.strftime("%Y-%m").alias("month")
        )
        .group_by("month")
        .agg(
            pl.sum("DAYS_SUPPLY").alias("total_days_supply")
        )
        .sort("month")  # Sort chronologically
    )

    # Create the Altair bar chart for monthly supply
    monthly_chart = alt.Chart(monthly_supply).mark_bar().encode(
        x=alt.X('month:O', title='Month', sort=None),  # Data is pre-sorted
        y=alt.Y('total_days_supply:Q', title='Total Days Supply'),
        tooltip=[
            alt.Tooltip('month:O', title='Month'),
            alt.Tooltip('total_days_supply:Q', title='Total Days Supply', format=',')
        ]
    ).properties(
        title=f"Total Monthly Days Supply for {title_brand} (Last 2 Years)",
        width=600
    )

    mo.ui.altair_chart(monthly_chart)
    return


# Running this file directly with Python runs the notebook's cells via app.run().
if __name__ == "__main__":
    app.run()
