#!/usr/bin/env python3
from beancount import loader
from beancount.query import query
from beancount.parser import printer
import argparse
from tabulate import tabulate
from decimal import Decimal
from beancount.core.amount import Amount, add, sub, mul
from math import floor
from datetime import datetime, timedelta

class bcolors:
    """ANSI escape sequences for colored/styled terminal output.

    Wrap text as ``f"{bcolors.OKGREEN}text{bcolors.ENDC}"``; ENDC resets
    every attribute back to the terminal default.
    """
    # Foreground colors (codes 9x are the "bright" ANSI color range).
    HEADER = '\x1b[95m'
    OKBLUE = '\x1b[94m'
    OKCYAN = '\x1b[96m'
    OKGREEN = '\x1b[92m'
    WARNING = '\x1b[93m'
    FAIL = '\x1b[91m'
    # Reset and text attributes.
    ENDC = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'

def draw_line(width=30):
  """Print a horizontal rule made of *width* box-drawing characters.

  Args:
    width: number of '─' characters to print (default 30, the original
      hard-coded length).
  """
  print('─' * width)

def get_amount(inventory, currency="EUR"):
  """Format the single position of *inventory* as an amount string.

  Args:
    inventory: beancount Inventory expected to hold at most one position;
      it is read through get_only_position(), and a None result is treated
      as an empty inventory.
    currency: currency of the zero amount shown for an empty inventory
      (default "EUR", the currency the report's queries convert to).

  Returns:
    The position's units rounded to 2 decimal places (e.g. "1234.56 EUR"),
    or a zero amount in *currency* when the inventory is empty.
  """
  # Look the position up once instead of calling get_only_position() twice.
  position = inventory.get_only_position()
  if position is None:
    return Amount(Decimal(0), currency).to_string()
  units = position.units
  return Amount(Decimal(round(units.number, 2)), units.currency).to_string()

def print_capital_gains(inventory):
  """Format capital gains as a colored amount string.

  The inventory's sign is flipped for display. Presumably this is because the
  source is an Income account, where gains are booked as negative numbers;
  TODO confirm. A non-positive raw number (a gain or zero after flipping) is
  shown in green, and a positive raw number (a loss) in red.

  Args:
    inventory: beancount Inventory with at most one position.

  Returns:
    The flipped amount rounded to 2 decimals, wrapped in ANSI color codes.
    Returns an uncolored "0 EUR" when the inventory is empty.
  """
  position = inventory.get_only_position()
  if position is None:
    return Amount(Decimal(0), "EUR").to_string()
  result = position.units
  num = Amount(Decimal(round(result.number * -1, 2)), result.currency).to_string()
  color = bcolors.OKGREEN if result.number <= 0 else bcolors.FAIL
  return f"{color}{num}{bcolors.ENDC}"

def print_perc_capital_gains(inventory, init_invest):
  """Format capital gains as a colored percentage of the initial balance.

  percentage = -gains / initial_balance * 100, rounded to 2 decimals.
  The gains are sign-flipped, as in print_capital_gains.

  Args:
    inventory: capital-gains Inventory with at most one position.
    init_invest: initial investment Inventory with at most one position.

  Returns:
    A colored "<num> %" string. Returns "0 %" when there are no gains. Returns
    "n/a" when the initial balance is empty or zero, because the percentage is
    undefined there (previously this crashed on None / raised a decimal
    division error).
  """
  gains = inventory.get_only_position()
  if gains is None:
    # This column is a percentage, so do not print a currency amount here.
    return "0 %"
  base = init_invest.get_only_position()
  if base is None or base.units.number == 0:
    return "n/a"
  result = gains.units
  num = round((result.number * -1 / base.units.number) * 100, 2)
  color = bcolors.OKGREEN if result.number <= 0 else bcolors.FAIL
  return f"{color}{num} %{bcolors.ENDC}"

def calc_contributions(contributions):
  """Sum the strictly positive posting amounts in *contributions*.

  Args:
    contributions: iterable of query rows exposing ``position.units.number``
      (as returned by get_contributions). Currencies are not checked.

  Returns:
    Decimal total of all positive numbers; Decimal(0) when there are none.
  """
  numbers = (row.position.units.number for row in contributions)
  # Negative postings (withdrawals/outflows) are deliberately ignored.
  return sum((n for n in numbers if n > 0), Decimal(0))

def get_returns(init_invest, end_invest, contributions):
  """Compute the period's investment return as a formatted amount string.

  return = end balance - initial balance - contributions

  Args:
    init_invest: Inventory (at most one position) for the start date.
    end_invest: Inventory (at most one position) for the end date.
    contributions: rows accepted by calc_contributions(); only positive
      amounts are counted.

  Returns:
    The return rounded to 2 decimals, e.g. "123.45 EUR".
  """
  zero = Amount(Decimal(0), "EUR")
  # An empty inventory makes get_only_position() return None. Treat it as a
  # zero balance instead of crashing on `None.units`; get_amount() already
  # handles an empty inventory this way.
  init_pos = init_invest.get_only_position()
  end_pos = end_invest.get_only_position()
  init = init_pos.units if init_pos is not None else zero
  end = end_pos.units if end_pos is not None else zero
  contr = calc_contributions(contributions)
  result = sub(sub(end, init), Amount(contr, "EUR"))
  return Amount(Decimal(round(result.number, 2)), result.currency).to_string()

def print_contributions(contributions):
  """Format the total of positive contributions as a 2-decimal EUR amount."""
  rounded = Decimal(round(calc_contributions(contributions), 2))
  return Amount(rounded, "EUR").to_string()

def print_report(date, end_date, init_invest, end_invest, contributions, capital_gains):
  """Print the investment report to stdout.

  The report has a bold title and a rule, then a Date/Balance table for
  *date* and *end_date*, then a summary table with contributions, returns,
  capital gains and capital gains as a percentage of the initial balance.
  """
  title = f"Investment returns (date={date})"
  print(f"{bcolors.BOLD}{title}{bcolors.ENDC}")
  draw_line()

  balances = [
      [date, get_amount(init_invest)],
      [end_date, get_amount(end_invest)],
  ]
  print(tabulate(balances, headers=['Date', 'Balance']))

  summary = [
      ["Contributions", print_contributions(contributions)],
      ["Returns", get_returns(init_invest, end_invest, contributions)],
      ["Capital gains", print_capital_gains(capital_gains)],
      ["Capital gains %", print_perc_capital_gains(capital_gains, init_invest)],
  ]
  print(tabulate(summary))

def get_investments(entries, options, date, end_date):
  """Return the EUR-converted ``Assets:Invest*`` balances at two dates.

  Args:
    entries, options: ledger data from beancount's loader.load_file().
    date, end_date: ISO date strings, interpolated verbatim into the query.

  Returns:
    ``(initial, final)``: the ``position`` column of the first result row of
    a query over entries dated on or before *date*, and on or before
    *end_date*. The queries run in that order.
  """
  def balance_at(day):
    # Aggregate query: sum and convert every Assets:Invest* posting up to *day*.
    bql = f"SELECT convert(sum(position), \"EUR\") as position FROM date <= {day} WHERE account ~ '^Assets:Invest'"
    _, rows = query.run_query(entries, options, bql)
    return rows[0].position

  return balance_at(date), balance_at(end_date)

def get_contributions(entries, options, start_date, end_date):
  """Fetch the individual postings to ``Assets:Liquid:R4:EUR*`` in a date range.

  Both ends are inclusive: start_date <= date <= end_date.

  NOTE(review): get_investments() measures the opening balance with
  ``date <= start_date``, so postings dated exactly on start_date can affect
  both the opening balance and the contributions. Confirm the intended
  boundary.

  Returns:
    The raw result rows; each exposes ``.position``.
  """
  bql = f"SELECT position FROM date <= {end_date} WHERE account ~ '^Assets:Liquid:R4:EUR' AND date >= {start_date}"
  _, rows = query.run_query(entries, options, bql)
  return rows

def get_capital_gains(entries, options, date, end_date):
  """Sum the ``Income:Invest:R4:CapitalGains*`` postings in [date, end_date], in EUR.

  Presumably the result is negative for a net gain, following beancount's
  Income sign convention; print_capital_gains flips the sign for display.

  Returns:
    The ``position`` column of the first row of the aggregate query.
  """
  bql = f"SELECT convert(sum(position), \"EUR\") as position FROM date <= {end_date} WHERE account ~ '^Income:Invest:R4:CapitalGains' AND date >= {date}"
  _, rows = query.run_query(entries, options, bql)
  first = rows[0]
  return first.position

def _iso_date(text):
  """argparse ``type=`` callback: parse a YYYY-MM-DD string into a date.

  Raises:
    argparse.ArgumentTypeError: if *text* is not a valid calendar date.
  """
  try:
    return datetime.strptime(text, "%Y-%m-%d").date()
  except ValueError:
    raise argparse.ArgumentTypeError(
      f"invalid date {text!r}, expected YYYY-MM-DD (e.g. 1970-01-01)") from None


def _add_one_year(day):
  """Return *day* moved one year forward; Feb 29 is clamped to Feb 28."""
  try:
    return day.replace(year=day.year + 1)
  except ValueError:
    # Only Feb 29 can fail here: the following year has no Feb 29.
    return day.replace(year=day.year + 1, day=28)


def main():
  """Command-line entry point: load the ledger and print a one-year report.

  The report period runs from the given date to the same date one year later.
  """
  parser = argparse.ArgumentParser(description='Generate investments report')
  parser.add_argument('date', metavar='date', type=_iso_date,
                      help='Report date in ISO format (e.g. 1970-01-01)')
  parser.add_argument('--file', default="ledger/main.beancount",
                      help='Beancount ledger to load (default: %(default)s)')
  args = parser.parse_args()

  entries, errors, options = loader.load_file(args.file)

  if errors:
    # Report ledger problems but still produce the report (best effort).
    printer.print_errors(errors)

  # isoformat() yields zero-padded YYYY-MM-DD, which the BQL queries expect
  # as date literals (e.g. "2020-1-1" from the CLI is normalized here).
  date = args.date.isoformat()
  end_date = _add_one_year(args.date).isoformat()
  init_invest, end_invest = get_investments(entries, options, date, end_date)
  contributions = get_contributions(entries, options, date, end_date)
  capital_gains = get_capital_gains(entries, options, date, end_date)
  print_report(date, end_date, init_invest, end_invest, contributions, capital_gains)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
  main()