#!/usr/bin/python3

#
# babackup_server
#
# Copyright (C) 2024 by John Heidemann <johnh@isi.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
#


import argparse
import sys
import os
import platform
import os.path
import stat
import io
import re
import shutil
import tempfile
import datetime
import subprocess
import logging
import pdb
# pdb.set_trace()

import yaml


def subprocess_run_capture_output(cmd):
    """Run CMD, capturing stdout/stderr as utf-8 text; return the CompletedProcess.

    Backwards compatible wrapper: subprocess.run() grew capture_output in
    Python 3.7, so on older interpreters we pass the equivalent PIPE arguments.
    We test sys.version_info rather than parsing platform.python_version_tuple(),
    because int() on its patchlevel string breaks on pre-releases (e.g. "0b1"),
    and a simple major==3 test would mis-route any future major version.
    """
    if sys.version_info >= (3, 7):
        return subprocess.run(cmd, capture_output = True, encoding = 'utf-8')
    else:
        return subprocess.run(cmd, encoding = 'utf-8', stdout = subprocess.PIPE, stderr = subprocess.PIPE)


class Program:
    """Server side of babackup.

    Construction parses command-line arguments and then immediately either
    registers a new backup (when --new-server-path is given) or scans all
    configured backups, rolling completed uploads forward and aging their
    archive directories.
    """

    # expected location of the restricted-rsync wrapper used for rrsync-mode backups
    RRSYNC_PATH = "/usr/local/bin/rrsync"

    def __init__(self):
        self.parse_args()
        if self.new_server_path is not None:
            self.configure_new_backup()
        else:
            self.check_backups()


    def verbose_log(self, s, verbosity = 1):
        """Our common logging function, both to stdout and the log.

        S is printed when self.verbose >= VERBOSITY; independently it goes
        to the log at INFO level for verbosity 1 and DEBUG for verbosity >= 2.
        """
        if self.verbose >= verbosity:
            print(s)
        if verbosity == 1:
            logging.info(s)
        elif verbosity >= 2:
            logging.debug(s)


    def configure_logging(self):
        """Set up formal logging, as given in the config.

        Uses logging.filename from the config when present, otherwise
        CONF_DIR/server.log; creates the log's directory if needed.
        """
        logging_path = None
        logging_conf = self.conf.get('logging')
        if logging_conf is not None and logging_conf.get("filename") is not None:
            logging_path = self.conf['logging']['filename']
        if logging_path is None:
            logging_path = self.conf_dir + "/server.log"
        logging_dirname = os.path.dirname(logging_path)
        # bug fix: os.path.mkdirs does not exist; use os.makedirs.
        # (Also skip the relative-path case where dirname is empty.)
        if logging_dirname and not os.path.isdir(logging_dirname):
            os.makedirs(logging_dirname)
        logging.basicConfig(filename  = logging_path, level = 'INFO', format="%(asctime)s: %(message)s")


    def parse_args(self):
        """Parse command-line arguments, store them on self, and return them.

        The *-to-keep counts default to -1, meaning "not specified";
        check_backup_archive() substitutes DEFAULT_TO_KEEP for negatives.
        """
        parser = argparse.ArgumentParser(description = 'backup things via rsync to a remote server from a client', epilog="""
babackup

        """)
        # see https://docs.python.org/3/library/argparse.html
        parser.add_argument('--conf', '-c', help='use configuration FILE.yaml (default is ~/.config/babackup/server.yaml or /etc/babackup/server.yaml)')
        parser.add_argument('--name', '-N', help='use backup NAME, or define new backup NAME')
        parser.add_argument('--daily-to-keep', help='how many daily backups to keep', type=int, default=-1)
        parser.add_argument('--weekly-to-keep', help='how many weekly backups to keep', type=int, default=-1)
        parser.add_argument('--monthly-to-keep', help='how many monthly backups to keep', type=int, default=-1)
        parser.add_argument('--triannual-to-keep', help='how many triannual backups to keep', type=int, default=-1)
        parser.add_argument('--new-server-path', help='location of a new backup')
        parser.add_argument('--new-pub-key', help='public ssh key for new backup')
        parser.add_argument('--new-mode', help='new backup is rrsync or ssh')
        # NOTE(review): --check-current takes a value (no action=); only
        # --no-check-current can actually turn checking off.  Left as-is
        # to preserve the command-line interface.
        parser.add_argument('--check-current', help='check for fresh backups (defaults on)', default=True)
        parser.add_argument('--no-check-current', help='do not check for fresh backups', dest='check_current', action='store_false')
        # check_archive defaults to None so check_backups() can distinguish
        # "unspecified" from an explicit yes/no
        parser.add_argument('--check-archive', help='check archives for aging, even if no new backups', action='store_true', default=None)
        parser.add_argument('--no-check-archive', help='check archives for aging, even if no new backups', dest='check_archive', action='store_false')
        parser.add_argument('--debug', '-d', help='debugging mode', action='store_true', default=False)
        parser.add_argument('--verbose', '-v', action='count', default=0)
        args = parser.parse_args()
        self.debug = args.debug
        self.verbose = args.verbose
        self.conf_path = args.conf
        self.name = args.name
        self.to_keep = dict()
        self.to_keep['daily'] = args.daily_to_keep
        self.to_keep['weekly'] = args.weekly_to_keep
        self.to_keep['monthly'] = args.monthly_to_keep
        self.to_keep['triannual'] = args.triannual_to_keep
        self.new_server_path = args.new_server_path
        self.new_pub_key = args.new_pub_key
        self.new_mode = args.new_mode
        self.check_current = args.check_current
        self.check_archive = args.check_archive
        self.DEFAULT_TO_KEEP = 10
        self.temp_dir = None
        return args


    def read_conf(self):
        """Figure out what configuration file we're using, then read it.

        Sets self.conf_dir, self.conf_path, and self.conf; a missing or
        empty configuration file yields an empty configuration.
        Also sets up logging (logging is configured by the config file).
        """
        # where?  root uses the system directory, others their home config
        self.conf_dir = "/etc/babackup"
        if os.getuid() != 0:
            self.conf_dir = os.path.expanduser("~") + "/.config/babackup"
        # what?
        if self.conf_path is None:
            self.conf_path = self.conf_dir + "/server.yaml"
        # read it
        try:
            with open(self.conf_path, 'r') as conf_stream:
                self.conf = yaml.safe_load(conf_stream)
        except IOError:
            self.conf = None
        if self.conf is None:
            # missing or empty config file (safe_load returns None for empty)
            self.conf = dict()
        if self.conf.get('backups') is None:
            self.conf['backups'] = []
        # also set up logging
        self.configure_logging()


    def write_conf(self):
        """Write self.conf back to self.conf_path, creating its directory if needed."""
        if not os.path.isdir(self.conf_dir):
            os.makedirs(self.conf_dir, mode=0o755)
        with open(self.conf_path, 'w+') as conf_stream:
            yaml.dump(self.conf, conf_stream)


    def check_crontab(self, program, location, suggestion):
        """See if the user has a cron for PROGRAM.
        If not, remind them to do SUGGESTION at LOCATION."""
        result = subprocess_run_capture_output(['crontab', '-l'])
        show_message = False
        if result.returncode == 1:
            # error 1 is no crontab at all
            show_message = True
        elif result.returncode == 0:
            with io.StringIO(result.stdout) as crontab_stream:
                for line in crontab_stream:
                    if line.startswith("#"):
                        continue
                    fields = line.split()
                    # a crontab line is 5 time fields then the command, so the
                    # command is fields[5] (bug fix: need >= 6 fields, not >= 5)
                    if len(fields) >= 6 and fields[5].endswith(program):
                        # hit
                        return
            show_message = True
        else:
            # ignore other errors
            pass
        if show_message:
            print(f"To automate {program}, add this crontab entry (crontab -e)\non the {location}:\n\n\t{suggestion}\n\n")


    def check_rrsync(self):
        """See if we have rrsync installed; if not, tell the user how to get it."""
        if os.path.exists(self.RRSYNC_PATH):
            return
        print(f"babackup_server requires rrsync exist in {self.RRSYNC_PATH}\n")
        typical_rrsync_source_path = "/usr/share/doc/rsync/support/rrsync"
        if os.path.exists(typical_rrsync_source_path):
            print(f"\nPlease install it by:\n\n\tcp {typical_rrsync_source_path} {self.RRSYNC_PATH}\n\tchmod +x {self.RRSYNC_PATH}\n")
        else:
            print(f"\nPlease copy rrsync from the support directory of rsync into {self.RRSYNC_PATH}\n")


    def configure_new_backup(self):
        """Configure a new backup.

        Update configuration files, create the target directory tree,
        install the ssh public key (rrsync mode), remind about cron, etc.
        Requires --name, --new-mode, --new-server-path, and (for rrsync)
        --new-pub-key; exits with a message if any is missing or invalid.
        """
        self.read_conf()
        if self.name is None:
            sys.exit("babackup_server: attempt to add new backup without specifying --name")
        if self.new_mode is None:
            sys.exit("babackup_server: attempt to add new backup without specifying --new-mode=MODE (rrsync or ssh)")
        mode = self.new_mode
        if mode not in ('local', 'ssh', 'rrsync'):
            # bug fix: the message was missing its f-prefix and printed "{mode}" literally
            sys.exit(f"babackup_server: unknown --new-mode={mode}\n")
        if self.new_server_path is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-server-path=server/partial/or/full/path")
        if mode == 'rrsync' and self.new_pub_key is None:
            sys.exit("babackup_server: attempt to add new rrsync backup without specifying --new-pub-key='ssh-foo BASE64 keyname'")

        # and some sanity checking, since we're going to put stuff in authorized_keys
        if re.search(r"\s", self.new_server_path):
            sys.exit("babackup_server: rejecting --new-server-path that contains whitespace")
        if self.new_pub_key is not None and self.new_pub_key.find("\n") >= 0:
            sys.exit("babackup_server: rejecting --new-pub-key that contains newline")

        name = self.name
        backup = dict()
        backup["name"] = self.name
        backup["server_path"] = self.new_server_path
        backup["mode"] = mode
        if mode == 'rrsync':
            backup["pub_key"] = self.new_pub_key

        #
        # create the target directory
        #
        full_new_server_path = self.new_server_path
        if full_new_server_path[0] == '~':
            full_new_server_path = os.path.expanduser(full_new_server_path)
        if not os.path.isdir(full_new_server_path):
            self.verbose_log(f"babackup_server: mkdir {full_new_server_path}", 1)
            if not self.debug:
                os.mkdir(full_new_server_path)
                os.mkdir(full_new_server_path + "/current")
                os.mkdir(full_new_server_path + "/current/last")
                os.mkdir(full_new_server_path + "/current/last/data")

        #
        # change authorized_keys
        #
        if mode == 'rrsync':
            auth = 'command="' + self.RRSYNC_PATH + ' -wo ' + self.new_server_path + '/current",no-agent-forwarding,no-port-forwarding,no-pty,no-user-rc,no-X11-forwarding ' + self.new_pub_key
            self.verbose_log(f"babackup_server: adding public key to ~/.ssh/authorized_keys, with rrsync\n\t{auth}", 1)
            if not self.debug:
                auth_path = os.path.expanduser("~") + "/.ssh/authorized_keys"
                with open(auth_path, "a") as auth_stream:
                    auth_stream.write(auth + "\n")

        self.check_crontab("babackup_server", "server", "1,16,31,46 * * * * /usr/sbin/babackup_server")
        self.check_rrsync()

        if self.conf['backups'] is None:
            self.conf['backups'] = []
        self.conf['backups'].append(backup)
        if self.debug:
            return
        self.write_conf()


    def check_backup(self, backup):
        """Check one backup with configuration BACKUP.

        If the client has finished an upload (begin then end markers in
        server_path/current), roll current to current/last, and move the
        previous completed "last" into server_path/archive, stamped with
        its begin mtime.  Sets backup['archive_check'] when an archive
        entry was created, so check_backups() will age the archive.
        """
        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        server_path = backup.get("server_path")
        if server_path is None:
            sys.exit(f"babackup_server: backup {name} has no path")
        if server_path[0] == "~":
            server_path = os.path.expanduser(server_path)

        #
        # see if this one finished a backup since last check
        #
        # 0. no begin: never started
        # 1. begin no end => in progress (or failed)
        # 2. begin and end, but end before begin => missed it and it's running again
        # 3. begin and end, but end after begin => good!
        #

        # no backup
        if not os.path.exists(f"{server_path}/current/begin"):
            self.verbose_log(f"babackup_server: {name} backup idle", 1)
            return

        if not os.path.exists(f"{server_path}/current/end"):
            self.verbose_log(f"babackup_server: {name} backup is active", 1)
            return

        begin_mtime = os.path.getmtime(f"{server_path}/current/begin")
        end_mtime = os.path.getmtime(f"{server_path}/current/end")
        if begin_mtime > end_mtime:
            self.verbose_log(f"babackup_server: {name} backup was complete but is active again", 2)
            return

        #
        # commit!
        #
        # (Note that there is a test-to-use race going on here :-( )
        #
        self.verbose_log(f"babackup_server: {name} is rolling current to last", 1)
        if not self.debug:
            if not os.path.exists(f"{server_path}/new"):
                os.mkdir(f"{server_path}/new")
            if os.path.exists(f"{server_path}/current"):
                os.rename(f"{server_path}/current", f"{server_path}/new/last")
            os.rename(f"{server_path}/new", f"{server_path}/current")

        #
        # move the old last into archive
        #
        # Note that we trust the file mtime rather than contents,
        # since the contents came from the user.
        #
        # Only do this if that backup looks good (has begin and end files).
        # Otherwise we let that last linger, eventually to be garbage collected
        # when current/last goes away
        #
        if os.path.isdir(f"{server_path}/current/last/last") and os.path.exists(f"{server_path}/current/last/last/begin") and os.path.exists(f"{server_path}/current/last/last/end"):
            last_begin_mtime = os.path.getmtime(f"{server_path}/current/last/last/begin")
            last_isotime = datetime.datetime.fromtimestamp(last_begin_mtime, datetime.timezone.utc).isoformat(timespec = 'minutes')
            self.verbose_log(f"babackup_server: {name} is moving last/last to archive/{last_isotime}", 2)
            # get rid of : in time, to be more filename friendly
            last_isotime = last_isotime.replace(":", "_")
            if not self.debug:
                if not os.path.isdir(f"{server_path}/archive"):
                    os.mkdir(f"{server_path}/archive")
                if os.path.isdir(f"{server_path}/archive/{last_isotime}"):
                    # xxx: we have two things with the same time, so the rename will fail.
                    pass
                else:
                    os.rename(f"{server_path}/current/last/last", f"{server_path}/archive/{last_isotime}")
            # we want to come back to this archive and check it later
            backup['archive_check'] = True


    def rmtree(self, path):
        """Like shutil.rmtree, but wrapped to handle permission errors.

        The error handler makes the offending path writable and retries;
        Python 3.12 renamed the onerror keyword to onexc (the handler only
        uses its first two arguments, so one handler serves both).
        """
        def redo_with_write(func, path, excinfo):
            os.chmod(path, stat.S_IWRITE)
            func(path)

        (major, minor, patchlevel) = list(map(int, platform.python_version_tuple()))
        if (major == 3 and minor >= 12) or major > 3:
            shutil.rmtree(path, onexc = redo_with_write)
        else:
            shutil.rmtree(path, onerror = redo_with_write)


    def check_backup_archive(self, backup):
        """Check the archive of one backup (configuration BACKUP) for outdated entries.

        Walks server_path/archive, removes leftover "~" corpses, and keeps
        a limited number of daily/weekly/monthly/triannual snapshots
        (counts from --*-to-keep, defaulting to DEFAULT_TO_KEEP); everything
        else is renamed to NAME~ and then deleted.
        """
        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        server_path = backup.get("server_path")
        # bug fix: test for None before dereferencing server_path[0]
        if server_path is None:
            sys.exit(f"babackup_server: backup {name} has no path")
        if server_path[0] == "~":
            server_path = os.path.expanduser(server_path)
        if not os.path.isdir(f"{server_path}/archive"):
            self.verbose_log(f"babackup_server: {name} check has no archive ({server_path}/archive)", 2)
            return

        self.verbose_log(f"babackup_server: {name} check archive", 2)

        #
        # walk the archives
        # finding things to delete (ending ~)
        # and to date check.
        #
        path_to_remove = []
        to_date_check = []
        part_to_timestamp = dict()
        timestamp_to_part = dict()

        iso_matcher = re.compile(r"^\d{4}-?\d{2}-?\d{2}([tT]\d{2}[-:_]?\d{2})?([-+]\d{2}[-:_]\d{2})?$")
        for part in os.listdir(f"{server_path}/archive"):
            if part.endswith("~"):
                self.verbose_log(f"babackup_server: {name} check finds corpse {part}", 2)
                path_to_remove.append(f"{server_path}/archive/{part}")
            elif iso_matcher.match(part):
                to_date_check.append(part)
                # undo our prior : removal (check_backup made times filename-friendly)
                clean_part = part.replace("_", ":")
                part_to_timestamp[part] = datetime.datetime.fromisoformat(clean_part).timestamp()
                timestamp_to_part[part_to_timestamp[part]] = part
            else:
                if self.verbose > 1:
                    self.verbose_log(f"babackup_server: {name} check finds suprising file {part} in {server_path}/archive", 3)

        #
        # apply the prioritization algorithm
        #
        to_keep = dict()
        keep_periods = ['daily', 'weekly', 'monthly', 'triannual']
        # minimum spacing for each period, slightly under the nominal
        # interval to tolerate jitter in backup start times
        keep_period_interval = [23*60*60, 7*86400-3600, 28*86400-3600*4, 28*86400*4-3600*8]
        if self.to_keep is None:
            self.to_keep = dict()
        for period in keep_periods:
            proposal = self.to_keep.get(period, self.DEFAULT_TO_KEEP)
            if proposal < 0:
                # negative means "unspecified on the command line"
                proposal = self.DEFAULT_TO_KEEP
            to_keep[period] = proposal

        newest_to_oldest_timestamps = sorted(part_to_timestamp.values(), reverse=True)
        keeping_what = keep_periods[0]
        keep_periods = keep_periods[1:]
        count_to_keep = to_keep[keeping_what]
        distance_too_close = keep_period_interval[0]
        keep_period_interval = keep_period_interval[1:]
        last_timestamp_kept = None
        timestamps_to_keep = []
        timestamps_to_retire = []
        for timestamp in newest_to_oldest_timestamps:
            if count_to_keep > 0 and ((last_timestamp_kept is None) or (last_timestamp_kept - timestamp > distance_too_close)):
                # keep this timestamp
                self.verbose_log(f"babackup_server: {name} keep {keeping_what} " + timestamp_to_part[timestamp], 2)
                timestamps_to_keep.append(timestamp)
                last_timestamp_kept = timestamp
                count_to_keep -= 1
                if count_to_keep <= 0:
                    # this period's quota is full; move on to the next, coarser one
                    if len(keep_periods) <= 0:
                        keeping_what = 'no_more'
                    else:
                        keeping_what = keep_periods[0]
                        keep_periods = keep_periods[1:]
                        count_to_keep = to_keep[keeping_what]
                        distance_too_close = keep_period_interval[0]
                        keep_period_interval = keep_period_interval[1:]
            else:
                # too close, so drop it
                self.verbose_log(f"babackup_server: {name} aging {keeping_what} " + timestamp_to_part[timestamp], 1)
                timestamps_to_retire.append(timestamp)

        #
        # take action
        #
        # first rename them (easy)
        for timestamp in timestamps_to_retire:
            part = timestamp_to_part[timestamp]
            self.verbose_log(f"babackup_server: {name} will retire {server_path}/archive/{part}", 1)
            if not self.debug:
                os.rename(f"{server_path}/archive/{part}", f"{server_path}/archive/{part}~")
            path_to_remove.append(f"{server_path}/archive/{part}~")
        # then remove them
        for old_path in path_to_remove:
            self.verbose_log(f"babackup_server: {name} is removing tree {old_path}", 2)
            if not self.debug:
                self.rmtree(old_path)


    def check_backups(self):
        """Check all backups (or whatever was specified with -N)."""
        self.read_conf()
        if self.check_current:
            for backup in self.conf['backups']:
                if (self.name is None or self.name == backup['name']) and backup.get("enabled", True):
                    self.check_backup(backup)
        #
        # Now that we've checked each, go back and look at aging.
        #
        force_check_archive = False
        if self.check_archive is not None:
            force_check_archive = self.check_archive
        for backup in self.conf['backups']:
            check_this_archive = False
            if backup.get("archive_check"):
                # check_backup() just archived something for this backup
                check_this_archive = True
            elif force_check_archive:
                check_this_archive = True
                if self.name is not None and self.name != backup.get('name', ''):
                    check_this_archive = False
            if check_this_archive:
                self.check_backup_archive(backup)


                
def main():
    """Entry point: constructing Program parses arguments and runs everything."""
    Program()
    return 0


if __name__ == '__main__':
    sys.exit(main())

