#!/usr/bin/env python3
from beancount import loader
from beancount.query import query
from beancount.core.data import Custom
from beancount.core.amount import Amount, add, sub
from beancount.parser import printer
import argparse
from datetime import date
from dateutil.relativedelta import relativedelta
from tabulate import tabulate
from decimal import Decimal
from functools import reduce

class bcolors:
    """ANSI escape sequences used to style text written to a terminal."""

    # Bright foreground colours (SGR codes 91-96).
    HEADER = "\033[95m"     # bright magenta
    OKBLUE = "\033[94m"     # bright blue
    OKCYAN = "\033[96m"     # bright cyan
    OKGREEN = "\033[92m"    # bright green
    WARNING = "\033[93m"    # bright yellow
    FAIL = "\033[91m"       # bright red

    # Attribute switches.
    ENDC = "\033[0m"        # reset every attribute back to the default
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

def get_budget_entries(entries, period, start_date):
  """Collect the budget directives that apply to a report.

  A directive is kept when it is a ``Custom`` entry carrying at least three
  values, its second value equals ``period`` and it is dated on or before
  ``start_date``.

  Args:
    entries: Iterable of loaded ledger directives; only ``Custom`` instances
      are inspected.
    period: Period label compared against the directive's second value.
    start_date: ISO date string (``YYYY-MM-DD``); directives dated after it
      are ignored.

  Returns:
    A list of dicts with the keys ``"date"``, ``"account"``, ``"period"`` and
    ``"budget"``, in the order the directives appear in ``entries``.

  Raises:
    ValueError: If ``start_date`` is not a valid ISO date.
  """
  # Parse once up front instead of once per matching directive; this also
  # rejects a malformed date before any entry is looked at.
  cutoff = date.fromisoformat(start_date)
  budgets = []
  for entry in entries:
    if not isinstance(entry, Custom):
      continue
    values = entry.values
    # Custom directives with fewer than three values cannot describe a budget
    # (account, period, amount); skip them rather than fail with IndexError.
    if len(values) < 3:
      continue
    if values[1].value != period or entry.date > cutoff:
      continue
    budgets.append({
      "date": entry.date,
      "account": values[0].value,
      "period": values[1].value,
      "budget": values[2].value,
    })
  return budgets

def get_expenses(entries, options, period, start_date):
  """Query the summed position of every expense account for one period.

  The period starts at ``start_date`` and spans one month when ``period`` is
  ``"monthly"``; any other value is treated as one year.

  Args:
    entries: Loaded ledger directives handed to the query engine.
    options: Ledger options handed to the query engine.
    period: ``"monthly"`` for a one-month window, otherwise one year.
    start_date: ISO date string (``YYYY-MM-DD``) opening the window.

  Returns:
    A dict mapping each account name returned by the query to its
    ``sum(position)`` value.
  """
  if period == "monthly":
    span = relativedelta(months=1)
  else:
    span = relativedelta(years=1)
  period_end = date.fromisoformat(start_date) + span

  bql = (
    "SELECT account, sum(position) "
    f"FROM OPEN ON {start_date} CLOSE ON {period_end.isoformat()} "
    "WHERE account ~ \"Expenses\""
  )
  # Only the result rows are needed; the column type descriptions are ignored.
  _, rows = query.run_query(entries, options, bql)
  return {row.account: row.sum_position for row in rows}

def build_budget(budget_entries, expenses):
  """Build the per-account rows of the budget report.

  Args:
    budget_entries: Iterable of dicts holding at least ``"account"`` and
      ``"budget"``; the budget must expose ``number``, ``currency`` and
      ``to_string()``.
    expenses: Mapping from account name to an object offering
      ``get_only_position()``; accounts missing from it count as unspent.

  Returns:
    A list with one dict per budget entry, keyed by ``"Account"``,
    ``"Budget"``, ``"Expense"``, ``"Expense (%)"`` and ``"Remaining"``.
    ``"Expense"`` is always an ``Amount`` so every row has the same shape.
  """
  result = []
  for entry in budget_entries:
    account = entry["account"]
    budget = entry["budget"]

    # Defaults describe an account without any spending in the period.
    spent = Amount(Decimal(0), budget.currency)
    expense_perc = 0
    remaining = budget

    position = None
    if account in expenses:
      # get_only_position() yields None for an empty inventory, which happens
      # when the postings of the period cancel out (e.g. a full refund).
      position = expenses[account].get_only_position()

    if position is not None:
      spent = position.units
      remaining = sub(budget, spent)
      if budget.number != 0:
        expense_perc = (spent.number / budget.number) * 100
      elif spent.number > 0:
        # No ratio exists against a zero budget; report any spending as fully
        # consumed so the row is still highlighted as over budget.
        expense_perc = 100

    result.append({
      "Account": account,
      "Budget": budget.to_string(),
      "Expense": spent,
      "Expense (%)": "{}{:,.2f}%{}".format(bcolors.FAIL if expense_perc >= 100 else '', expense_perc, bcolors.ENDC),
      "Remaining": remaining
    })
  return result

def print_report(budget_report, period, start_date, budget_sum, expenses_sum):
  """Print the report title, the totals and the per-account table to stdout.

  Args:
    budget_report: Rows to tabulate; the table headers are taken from the
      keys of the rows.
    period: Period label shown in the title line.
    start_date: Start date shown in the title line.
    budget_sum: Total budget, printed as is.
    expenses_sum: Total expenses; the line is printed in the ``FAIL`` colour
      when it is greater than or equal to ``budget_sum``.
  """
  print(f"Budget Report (period={period}, start_date={start_date})")
  print(f"Budget: {budget_sum}")
  highlight = bcolors.FAIL if expenses_sum >= budget_sum else ''
  print(f"{highlight}Expenses: {expenses_sum}{bcolors.ENDC}")
  print(tabulate(budget_report, headers="keys", numalign="right", floatfmt=".2f"))

def main():
  """Parse the command line, load the ledger and print the budget report.

  Exits through ``parser.error`` when ``start_date`` is not a valid ISO date.
  Prints a short notice and returns when no budget entry matches the request.
  """
  parser = argparse.ArgumentParser(description='Generate budget report')
  parser.add_argument('start_date', metavar='start_date', type=str, nargs=1,
                      help='Start date (end date will be one month after if monthly report or one year after if yearly report)')
  parser.add_argument('-p', metavar='period', type=str, choices=["monthly", "yearly"], default="monthly", required=False,
                      help='Period (monthly or yearly)')
  parser.add_argument('-f', metavar='file', type=str, default="ledger/main.beancount", required=False,
                      help='Path of the beancount ledger (default: ledger/main.beancount)')

  args = parser.parse_args()
  start_date = args.start_date[0]
  period = args.p

  # Reject a malformed date with a usage message instead of a traceback.
  try:
    date.fromisoformat(start_date)
  except ValueError:
    parser.error(f"start_date must be an ISO date (YYYY-MM-DD), got {start_date!r}")

  entries, errors, options = loader.load_file(args.f)

  # Loading problems are reported but do not stop the report.
  if errors:
    printer.print_errors(errors)

  budget_entries = get_budget_entries(entries, period, start_date)
  if not budget_entries:
    # Without a budget there is no currency to total in and nothing to show.
    print(f"No {period} budget entries found on or before {start_date}")
    return

  # TODO: Multiple currencies
  currency = budget_entries[0]["budget"].currency
  budget_sum = reduce(lambda total, entry: add(total, entry["budget"]), budget_entries, Amount(Decimal(0), currency))

  expenses = get_expenses(entries, options, period, start_date)
  # Keyed by account so an account budgeted more than once is totalled once.
  filtered_expenses = {}
  for entry in budget_entries:
    if entry["account"] in expenses:
      filtered_expenses[entry["account"]] = expenses[entry["account"]]

  expenses_sum = Amount(Decimal(0), currency)
  for inventory in filtered_expenses.values():
    position = inventory.get_only_position()
    # An empty inventory (postings that cancel out) has no position to add.
    if position is not None:
      expenses_sum = add(expenses_sum, position.units)

  budget_report = build_budget(budget_entries, expenses)
  print_report(budget_report, period, start_date, budget_sum, expenses_sum)

# Run the report only when this file is executed as a script, so that
# importing the module does not parse sys.argv or load a ledger.
if __name__ == "__main__":
  main()