#!/usr/bin/python3

#
# babackup_server
#
# Copyright (C) 2024 by John Heidemann <johnh@isi.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
#


import argparse
import sys
import os
import os.path
import re
import shutil
import tempfile
import datetime
import subprocess
import logging
import pdb
# pdb.set_trace()

import yaml


class Program:
    """Server side of babackup.

    In "configure" mode (--new-path) it registers a new backup:
    creates the target directory tree, optionally installs an
    rrsync-restricted key in ~/.ssh/authorized_keys, and records the
    backup in the server config.  Otherwise it checks each configured
    backup for a newly completed client upload, rolls current -> last
    -> archive, and ages archives with a daily/weekly/monthly/yearly
    retention policy.
    """

    def __init__(self):
        self.parse_args()
        # --new-path selects configure mode; otherwise check all backups
        if self.new_path is not None:
            self.configure_new_backup()
        else:
            self.check_backups()


    def verbose_log(self, s, status = 'info'):
        """Our common logging function: S goes to stdout when --verbose,
        and always to the logfile.

        STATUS selects the log level: 'warn' for warnings, anything
        else (the default 'info') logs at info level.
        """
        if self.verbose:
            print(s)
        # honor the status argument (it used to be ignored, so
        # warnings were logged as plain info)
        if status == 'warn':
            logging.warning(s)
        else:
            logging.info(s)


    def configure_logging(self):
        """Set up formal logging, as given in the config.

        Uses logging: filename: from the config when present,
        otherwise CONF_DIR/server.log.
        """
        logging_path = None
        logging_conf = self.conf.get('logging')
        if logging_conf is not None and logging_conf.get("filename") is not None:
            logging_path = logging_conf["filename"]
        if logging_path is None:
            logging_path = self.conf_dir + "/server.log"
        logging.basicConfig(filename  = logging_path, level = 'INFO', format="%(asctime)s: %(message)s")


    def parse_args(self):
        """Parse command-line arguments and stash them on self.

        Returns the parsed argparse namespace (also recorded on self).
        """
        parser = argparse.ArgumentParser(description = 'backup things via rsync to a remote server from a client', epilog="""
babackup

        """)
        parser.add_argument('--conf', '-c', help='use configuration FILE.yaml (default is ~/.config/babackup/server.yaml or /etc/babackup/server.yaml)')
        parser.add_argument('--name', '-N', help='use backup NAME, or define new backup NAME')
        parser.add_argument('--daily-to-keep', help='how many daily backups to keep', type=int, default=-1)
        parser.add_argument('--weekly-to-keep', help='how many weekly backups to keep', type=int, default=-1)
        parser.add_argument('--monthly-to-keep', help='how many monthly backups to keep', type=int, default=-1)
        parser.add_argument('--yearly-to-keep', help='how many yearly backups to keep', type=int, default=-1)
        parser.add_argument('--new-path', help='location of a new backup')
        parser.add_argument('--new-pub-key', help='public ssh key for new backup')
        parser.add_argument('--new-mode', help='new backup is rrsync or ssh')
        parser.add_argument('--check-current', help='check for fresh backups (defaults on)', default=True)
        parser.add_argument('--no-check-current', help='do not check for fresh backups', dest='check_current', action='store_false')
        parser.add_argument('--check-archive', help='check archives for aging, even if no new backups', action='store_true', default=None)
        parser.add_argument('--no-check-archive', help='do not check archives for aging', dest='check_archive', action='store_false')
        parser.add_argument('--debug', '-d', help='debugging mode', action='store_true', default=False)
        # default=0 (not None) so comparisons like "self.verbose > 1" are safe
        parser.add_argument('--verbose', '-v', action='count', default=0)
        args = parser.parse_args()
        self.debug = args.debug
        self.verbose = args.verbose
        self.conf_path = args.conf
        self.name = args.name
        # retention counts per period; -1 means "use DEFAULT_TO_KEEP"
        self.to_keep = {
            'daily': args.daily_to_keep,
            'weekly': args.weekly_to_keep,
            'monthly': args.monthly_to_keep,
            'yearly': args.yearly_to_keep,
        }
        self.new_path = args.new_path
        self.new_pub_key = args.new_pub_key
        self.new_mode = args.new_mode
        self.check_current = args.check_current
        self.check_archive = args.check_archive
        self.DEFAULT_TO_KEEP = 10
        self.temp_dir = None
        return args


    def read_conf(self):
        """Figure out what configuration file we're using, then read it.

        Sets self.conf_dir, self.conf_path, and self.conf (with a
        guaranteed 'backups' list), and starts logging.
        """
        # where?
        self.conf_dir = "/etc/babackup"
        if os.getuid() != 0:
            self.conf_dir = os.path.expanduser("~") + "/.config/babackup"
        # what?
        if self.conf_path is None:
            self.conf_path = self.conf_dir + "/server.yaml"
        # read it
        try:
            with open(self.conf_path, 'r') as conf_stream:
                self.conf = yaml.safe_load(conf_stream)
        except IOError:
            self.conf = None
        # safe_load returns None for an empty file; a missing file is
        # also an empty config
        if self.conf is None:
            self.conf = dict()
        # 'backups' is a *list* of per-backup dicts (the rest of the
        # code iterates and appends); it used to default to a dict,
        # which broke configure_new_backup's append()
        if self.conf.get('backups') is None:
            self.conf['backups'] = []
        # also set up logging
        self.configure_logging()


    def write_conf(self):
        """Write self.conf back to the config file, creating conf_dir if needed."""
        if not os.path.isdir(self.conf_dir):
            os.makedirs(self.conf_dir, mode=0o755)
        with open(self.conf_path, 'w+') as conf_stream:
            yaml.dump(self.conf, conf_stream)


    def configure_new_backup(self):
        """Configure a new backup.

        Update configuration files, create the target directory tree,
        and (for rrsync mode) install the public key in
        ~/.ssh/authorized_keys with a write-only rrsync restriction.
        Exits with a message on any missing or unsafe argument.
        """
        self.read_conf()
        if self.name is None:
            sys.exit("babackup: attempt to add new backup without specifying --name")
        if self.new_mode is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-mode=MODE (rrsync or ssh)")
        mode = self.new_mode
        if self.new_path is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-path=server/partial/or/full/path")
        if mode == 'rrsync' and self.new_pub_key is None:
            sys.exit("babackup_server: attempt to add new rrsync backup without specifying --new-pub-key='ssh-foo BASE64 keyname'")

        # and some sanity checking, since we're going to put stuff in authorized_keys
        if re.search(r"\s", self.new_path):
            sys.exit("babackup_server: rejecting --new-path that contains whitespace")
        # new_pub_key may legitimately be absent in ssh mode
        if self.new_pub_key is not None and self.new_pub_key.find("\n") >= 0:
            sys.exit("babackup_server: rejecting --new-pub-key that contains newline")

        backup = dict()
        backup["name"] = self.name
        backup["path"] = self.new_path
        backup["mode"] = mode
        if mode == 'rrsync':
            backup["pub-key"] = self.new_pub_key

        #
        # create the target directory tree
        # (current/last/data is where the client's rsync lands)
        #
        full_new_path = self.new_path
        if full_new_path[0] == '~':
            full_new_path = os.path.expanduser(full_new_path)
        if not os.path.isdir(full_new_path):
            if self.verbose:
                print(f"babackup_server: mkdir {full_new_path}")
            if not self.debug:
                os.mkdir(full_new_path)
                os.mkdir(full_new_path + "/current")
                os.mkdir(full_new_path + "/current/last")
                os.mkdir(full_new_path + "/current/last/data")

        #
        # change authorized_keys
        #
        if mode == 'rrsync':
            # restrict the key to write-only rsync into .../current
            auth = 'command="/usr/local/bin/rrsync -wo ' + self.new_path + '/current",no-agent-forwarding,no-port-forwarding,no-pty,no-user-rc,no-X11-forwarding ' + self.new_pub_key
            if self.verbose:
                print(f"babackup_server: adding public key to ~/.ssh/authorized_keys, with rrsync\n\t{auth}")
            if not self.debug:
                auth_path = os.path.expanduser("~") + "/.ssh/authorized_keys"
                with open(auth_path, "a") as auth_stream:
                    auth_stream.write(auth + "\n")

        if self.conf['backups'] is None:
            self.conf['backups'] = []
        self.conf['backups'].append(backup)
        if self.debug:
            return
        self.write_conf()


    def check_backup(self, backup):
        """Check one backup with configuration BACKUP.

        If the client has completed an upload (begin and end marker
        files present, end newer), roll current to last and move the
        previous last into the dated archive; mark the backup for a
        later archive-aging pass via backup['archive_check'].
        """

        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        path = backup.get("path")
        if path is None:
            sys.exit(f"babackup_server: backup {name} has no path:")
        if path[0] == "~":
            path = os.path.expanduser(path)

        #
        # see if this one finished a backup since last check
        #
        # 0. no begin: never started
        # 1. begin no end => in progress (or failed)
        # 2. begin and end, but end before begin => missed it and it's running again
        # 3. begin and end, but end after begin => good!
        #

        # no backup
        if not os.path.exists(f"{path}/current/begin"):
            self.verbose_log(f"babackup_server: {name} backup idle")
            return

        if not os.path.exists(f"{path}/current/end"):
            self.verbose_log(f"babackup_server: {name} backup is active")
            return

        begin_mtime = os.path.getmtime(f"{path}/current/begin")
        end_mtime = os.path.getmtime(f"{path}/current/end")
        if begin_mtime > end_mtime:
            self.verbose_log(f"babackup_server: {name} backup was complete but is active again")
            return

        #
        # commit!
        #
        # (Note that there is a test-to-use race going on here :-( )
        #
        self.verbose_log(f"babackup_server: {name} is rolling current to last")
        if not self.debug:
            if not os.path.exists(f"{path}/new"):
                os.mkdir(f"{path}/new")
            os.rename(f"{path}/current", f"{path}/new/last")
            os.rename(f"{path}/new", f"{path}/current")

        #
        # move the old last into archive
        #
        # Note that we trust the file mtime rather than contents,
        # since the contents came from the user.
        #
        # Only do this if that backup looks good (has begin and end files).
        # Otherwise we let that last linger, eventually to be garbage collected
        # when current/last goes away
        #
        if os.path.isdir(f"{path}/current/last/last") and os.path.exists(f"{path}/current/last/last/begin") and os.path.exists(f"{path}/current/last/last/end"):
            self.verbose_log(f"babackup_server: {name} is moving last/last to archive")
            last_begin_mtime = os.path.getmtime(f"{path}/current/last/last/begin")
            last_isotime = datetime.datetime.fromtimestamp(last_begin_mtime, datetime.timezone.utc).isoformat(timespec = 'minutes')
            # get rid of : in time, to be more filename friendly
            last_isotime = last_isotime.replace(":", "_")
            if not self.debug:
                if not os.path.isdir(f"{path}/archive"):
                    os.mkdir(f"{path}/archive")
                # (this duplicate check and the rename used to be nested
                # inside the mkdir branch above, so archiving only ever
                # happened the first time the archive dir was created)
                if os.path.isdir(f"{path}/archive/{last_isotime}"):
                    # xxx: we have two things with the same time, so the rename will fail.
                    # So just skip it.
                    return
                os.rename(f"{path}/current/last/last", f"{path}/archive/{last_isotime}")
            # we want to come back to this archive and check it later
            backup['archive_check'] = True


    def check_backup_archive(self, backup):
        """Check the archive of one backup (configuration BACKUP) for outdated entries.

        Keeps the newest archives at daily spacing, the next batch at
        weekly spacing, and so on through monthly and yearly; archives
        spaced too closely (or beyond all retention periods) are
        renamed with a trailing '~' and removed.
        """

        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        path = backup.get("path")
        # check for a missing path *before* indexing it
        if path is None:
            sys.exit(f"babackup_server: backup {name} has no path:")
        if path[0] == "~":
            path = os.path.expanduser(path)
        if not os.path.isdir(f"{path}/archive"):
            self.verbose_log(f"babackup_server: {name} check has no archive ({path}/archive)")
            return

        self.verbose_log(f"babackup_server: {name} check archive")

        #
        # walk the archives
        # finding things to delete (ending ~)
        # and to date check.
        #
        path_to_remove = []
        to_date_check = []
        part_to_timestamp = dict()
        timestamp_to_part = dict()

        iso_matcher = re.compile(r"^\d{4}-?\d{2}-?\d{2}([tT]\d{2}[-:_]?\d{2})?([-+]\d{2}[-:_]\d{2})?$")
        for part in os.listdir(f"{path}/archive"):
            if part.endswith("~"):
                # a trailing ~ marks a half-removed corpse from a prior run
                self.verbose_log(f"babackup_server: {name} check finds corpse {part}")
                path_to_remove.append(f"{path}/archive/{part}")
            elif iso_matcher.match(part):
                to_date_check.append(part)
                # undo our prior : removal (archive names use _ for :)
                clean_part = part.replace("_", ":")
                part_to_timestamp[part] = datetime.datetime.fromisoformat(clean_part).timestamp()
                timestamp_to_part[part_to_timestamp[part]] = part
            else:
                if self.verbose > 1:
                    self.verbose_log(f"babackup_server: {name} check finds suprising file {part} in {path}/archive", "warn")

        #
        # apply the priorization algorithm
        #
        to_keep = dict()
        all_to_keep = 0
        keep_periods = ['daily', 'weekly', 'monthly', 'yearly']
        # minimum seconds between two kept archives, per period
        # (slightly under the nominal period to tolerate timing jitter)
        keep_period_interval = [23*60*60, 7*86400-3600, 30*86400-3600, 364*86400]
        if self.to_keep is None:
            self.to_keep = dict()
        for period in keep_periods:
            proposal = self.to_keep.get(period, self.DEFAULT_TO_KEEP)
            if proposal < 0:
                proposal = self.DEFAULT_TO_KEEP
            to_keep[period] = proposal
            all_to_keep += to_keep[period]

        archive_count = len(to_date_check)
        if archive_count < all_to_keep:
            self.verbose_log(f"babackup_server: no need to age archive ({archive_count} < {all_to_keep} limit)")
            return

        newest_to_oldest_timestamps = sorted(part_to_timestamp.values(), reverse=True)
        # walk periods with an index so count and interval advance together
        # (the old code never advanced the period name and assigned the
        # *list tail* to distance_too_close, raising TypeError)
        period_index = 0
        keeping_what = keep_periods[period_index]
        count_to_keep = to_keep[keeping_what]
        distance_too_close = keep_period_interval[period_index]
        last_timestamp_kept = None
        timestamps_to_keep = []
        timestamps_to_retire = []
        for timestamp in newest_to_oldest_timestamps:
            if count_to_keep > 0 and ((last_timestamp_kept is None) or (last_timestamp_kept - timestamp > distance_too_close)):
                # keep this timestamp
                self.verbose_log(f"babackup_server: keep {keeping_what} " + timestamp_to_part[timestamp])
                timestamps_to_keep.append(timestamp)
                last_timestamp_kept = timestamp
                count_to_keep -= 1
                if count_to_keep <= 0:
                    # this period is full; move on to the next coarser one
                    period_index += 1
                    if period_index >= len(keep_periods):
                        # count stays 0, so everything older is retired
                        keeping_what = 'no_more'
                    else:
                        keeping_what = keep_periods[period_index]
                        count_to_keep = to_keep[keeping_what]
                        distance_too_close = keep_period_interval[period_index]
            else:
                # too close, so drop it
                self.verbose_log(f"babackup_server: aging {keeping_what} " + timestamp_to_part[timestamp])
                timestamps_to_retire.append(timestamp)

        #
        # take action
        #
        # first rename them (easy)
        for timestamp in timestamps_to_retire:
            part = timestamp_to_part[timestamp]
            self.verbose_log(f"babackup_server: {name} will retire {path}/archive/{part}")
            if not self.debug:
                os.rename(f"{path}/archive/{part}", f"{path}/archive/{part}~")
            path_to_remove.append(f"{path}/archive/{part}~")
        # then remove them
        for old_path in path_to_remove:
            self.verbose_log(f"babackup_server: {name} is removing tree {old_path}")
            if not self.debug:
                shutil.rmtree(old_path)


    def check_backups(self):
        """Check all backups (or whatever was specified with -N)."""
        self.read_conf()
        if self.check_current:
            for backup in self.conf['backups']:
                if (self.name is None or self.name == backup['name']) and backup.get("enabled", True):
                    self.check_backup(backup)
        #
        # Now that we've checked each, go back and look at aging.
        #
        force_check_archive = False
        if self.check_archive is not None:
            force_check_archive = self.check_archive
        for backup in self.conf['backups']:
            check_this_archive = False
            if backup.get("archive_check"):
                # a fresh archive was made this run
                check_this_archive = True
            elif force_check_archive:
                check_this_archive = True
                if self.name is not None and self.name != backup.get('name', ''):
                    check_this_archive = False
            if check_this_archive:
                self.check_backup_archive(backup)


                
if __name__ == '__main__':
    # Program's constructor does all the work (parse args, then
    # configure a new backup or check the existing ones).
    Program()
    raise SystemExit(0)

