#!/usr/bin/python3

#
# babackup_server
#
# Copyright (C) 2024 by John Heidemann <johnh@isi.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
#

# pylint: disable=line-too-long, trailing-whitespace, trailing-newlines, no-else-return, fixme, invalid-name


import argparse
import sys
import os
import platform
import os.path
import stat
import io
import re
import shutil
import tempfile
import datetime
import subprocess
import logging
import pdb
# pdb.set_trace()

import yaml


def subprocess_run_capture_output(cmd):
    """Run CMD and capture its stdout/stderr as utf-8 text.

    Wrapper around subprocess.run so that capture also works on
    python < 3.7, where the capture_output keyword does not exist.
    """
    major, minor, _ = (int(part) for part in platform.python_version_tuple())
    if (major, minor) >= (3, 7):
        return subprocess.run(cmd, capture_output = True, encoding = 'utf-8', check=False)
    # pre-3.7 spelling of the same request
    return subprocess.run(cmd, encoding = 'utf-8',
                          stdout = subprocess.PIPE, stderr = subprocess.PIPE,
                          check=False)


def datetime_datetime_fromisoformat_compat(s):
    """backwards compatible datetime.datetime.fromisoformat for python 3.6

    On python >= 3.7 this is just datetime.datetime.fromisoformat(s).
    On 3.6 we hand-parse the subset of ISO 8601 that babackup itself
    generates (YYYY-MM-DD[Thh:mm][+hh:mm], separators optional) and
    assume UTC; exits via sys.exit on an unparseable string.
    """
    (major, minor, _) = list(map(int, platform.python_version_tuple()))
    if (major == 3 and minor >= 7) or major > 3:
        return datetime.datetime.fromisoformat(s)
    # our poor version
    m = re.match(r"^(\d{4})-?(\d{2})-?(\d{2})([tT]?(\d{2})[-:_]?(\d{2}))?([-+]\d{2}[-:_]\d{2})?$", s)
    if not m:
        # we assume we are only called with good dates
        # (bug fix: the message was missing its f-prefix, so it printed "{s}" literally)
        sys.exit(f"babackup_server: gave up on parsing {s}")
    year, month, day = int(m.group(1)), int(m.group(2)), int(m.group(3))
    # bug fix: a missing time used to default to 01:01; real fromisoformat
    # treats a bare date as midnight, so default the time fields to 0
    hour = int(m.group(5)) if m.group(5) is not None else 0
    minute = int(m.group(6)) if m.group(6) is not None else 0
    # NOTE(review): any explicit offset (group 7) is ignored and UTC assumed;
    # fine for our own archive names, which are always written with +00:00
    return datetime.datetime(year, month, day, hour, minute, tzinfo = datetime.timezone.utc)


class Program:
    RRSYNC_PATH = "/usr/local/bin/rrsync"
    
    def __init__(self):
        # if assertion_fails:
        #     raise Exception("Assertion failed")
        # or better sys.exit("Assertion failed")
        self.parse_args()
        if self.new_server_path is not None:
            self.configure_new_backup()
        else:
            self.check_backups()


    def verbose_log(self, s, verbosity = 1):
        """our common logging function, both to stdout and the log"""
        # echo to the console when the user gave enough -v flags
        if self.verbose >= verbosity:
            print(s)
        # and always mirror into the formal log at a matching level
        if verbosity >= 2:
            logging.debug(s)
        elif verbosity == 1:
            logging.info(s)


    def configure_logging(self):
        """set up formal logging, as given in the config"""
        logging_path = None
        logging_conf = self.conf.get('logging')
        if logging_conf is not None and logging_conf.get("filename") is not None:
            logging_path = self.conf['logging']['filename']
        if logging_path is None:
            logging_path = self.conf_dir + "/server.log"
        logging_dirname = os.path.dirname(logging_path)
        if not os.path.isdir(logging_dirname):
            os.path.makedirs(logging_dirname)
        logging.basicConfig(filename  = logging_path, level = 'INFO', format="%(asctime)s: %(message)s")


    def parse_args(self):
        """parse all the args"""
        parser = argparse.ArgumentParser(description = 'backup things via rsync to a remote server from a client', epilog="""
babackup

        """)
        # see https://docs.python.org/3/library/argparse.html
        #  ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])

        #  parser.add_argument('--focus', help='focus on a given TARGET', choices=['us', 'nynj', 'coverage'], default='us')
        #  parser.add_argument('--output', '-o', help='output FILE')
        #  parser.add_argument('--duty-cycle', help='duty cycle (a float)', type=float)
        #  parser.add_argument('--type', '-t', choices=['pdf', 'png'], help='type of output (pdf or png)', default = 'pdf')
        #  parser.add_argument('--day', type=int, help='day to plot', default = None)
        parser.add_argument('--conf', '-c', help='use configuration FILE.yaml (default is ~/.config/babackup/server.yaml or /etc/babackup/server.yaml)')
        parser.add_argument('--name', '-N', help='use backup NAME, or define new backup NAME')
        parser.add_argument('--daily-to-keep', help='how many daily backups to keep', type=int, default=-1)
        parser.add_argument('--weekly-to-keep', help='how many weekly backups to keep', type=int, default=-1)
        parser.add_argument('--monthly-to-keep', help='how many monthly backups to keep', type=int, default=-1)
        parser.add_argument('--triannual-to-keep', help='how many triannual backups to keep', type=int, default=-1)
        parser.add_argument('--new-server-path', help='location of a new backup')
        parser.add_argument('--new-pub-key', help='public ssh key for new backup')
        parser.add_argument('--new-mode', help='new backup is rrsync or ssh')
        parser.add_argument('--check-current', help='check for fresh backups (defaults on)', default=True)
        parser.add_argument('--no-check-current', help='do not check for fresh backups', dest='check_current', action='store_false')
        parser.add_argument('--check-archive', help='check archives for aging, even if no new backups', action='store_true', default=None)
        parser.add_argument('--no-check-archive', help='check archives for aging, even if no new backups', dest='check_archive', action='store_false')
        parser.add_argument('--status', help='show status', action='store_true', default=False)
        parser.add_argument('--debug', '-d', help='debugging mode', action='store_true', default=False)
        parser.add_argument('--verbose', '-v', action='count', default=0)
        args = parser.parse_args()
        self.debug = args.debug
        self.verbose = args.verbose
        self.conf_path = args.conf
        self.name = args.name
        self.to_keep = {}
        self.to_keep['daily'] = args.daily_to_keep
        self.to_keep['weekly'] = args.weekly_to_keep
        self.to_keep['monthly'] = args.monthly_to_keep
        self.to_keep['triannual'] = args.triannual_to_keep
        self.new_server_path = args.new_server_path
        self.new_pub_key = args.new_pub_key
        self.new_mode = args.new_mode
        self.check_current = args.check_current
        self.check_archive = args.check_archive
        self.DEFAULT_TO_KEEP = 10
        self.temp_dir = None
        self.status = args.status
        if self.status:
            self.debug = True
            if self.verbose < 2:
                self.verbose = 2
        return args


    def read_conf(self):
        """figure out what configuration file we're using, then read and return it"""
         # where?
        self.conf_dir = "/etc/babackup"
        if os.getuid() != 0:
            self.conf_dir = os.path.expanduser("~") + "/.config/babackup"
        # what?
        if self.conf_path is None:
            self.conf_path = self.conf_dir + "/server.yaml"
        # read it
        try:
            with open(self.conf_path, 'r', encoding='utf-8') as conf_stream:
                self.conf = yaml.safe_load(conf_stream)
        except IOError:
            self.conf = {}
        if self.conf.get('backups') is None:
            self.conf['backups'] = []
        # also set up logging
        self.configure_logging()


    def write_conf(self):
        """Serialize self.conf back out to the configuration file."""
        # the configuration directory may not exist yet on first save
        conf_dir_missing = not os.path.isdir(self.conf_dir)
        if conf_dir_missing:
            os.makedirs(self.conf_dir, mode=0o755)
        with open(self.conf_path, 'w+', encoding='utf-8') as out_stream:
            yaml.dump(self.conf, out_stream)


    def check_crontab(self, program, location, suggestion):
        """see if the user has a cron for PROGRAM
        If not, remind them to do SUGGESTION at LOCATION"""
        result = subprocess_run_capture_output(['/usr/bin/crontab', '-l'])
        show_message = False
        if result.returncode == 1:
            # error 1 is no crontab
            show_message = True
        elif result.returncode == 0:
            with io.StringIO(result.stdout) as crontab_stream:
                for line in crontab_stream:
                    if line.startswith("#"):
                        continue
                    fields = line.split()
                    # a crontab entry is 5 time fields and then the command;
                    # bug fix: fields[5] requires at least 6 fields, not 5
                    # (the old >= 5 test could raise IndexError)
                    if len(fields) >= 6 and fields[5].endswith(program):
                        # hit: already automated, nothing to suggest
                        return
            show_message = True
        else:
            # ignore other errors (crontab missing, permission trouble, ...)
            pass
        if show_message:
            print(f"To automate {program}, add this crontab entry (crontab -e)\non the {location}:\n\n\t{suggestion}\n\n")


    def check_rrsync(self):
        """Complain, with installation instructions, unless rrsync is installed."""
        if os.path.exists(self.RRSYNC_PATH):
            # already installed, nothing to say
            return
        print(f"babackup_server requires rrsync exist in {self.RRSYNC_PATH}\n")
        # rsync usually ships a copy in its documentation directory
        typical_rrsync_source_path = "/usr/share/doc/rsync/support/rrsync"
        if not os.path.exists(typical_rrsync_source_path):
            print(f"\nPlease copy rrsync from the support directory of rsync into {self.RRSYNC_PATH}\n")
        else:
            print(f"\nPlease install it by:\n\n\tcp {typical_rrsync_source_path} {self.RRSYNC_PATH}\n\tchmod +x {self.RRSYNC_PATH}\n")


    def configure_new_backup(self):
        """configure a new backup
Update configuration files, generate keys, say what to do on the server, etc.
        """
        self.read_conf()
        if self.name is None:
            sys.exit("babackup_server: attempt to add new backup without specifying --name")
        if self.new_mode is None:
            sys.exit("babackup_server: attempt to add new backup without specifying --new-mode=MODE (rrsync or ssh)")
        mode = self.new_mode
        if not mode in ('local', 'ssh', 'rrsync'):
            sys.exit("babackup_server: unknown --new-mode={mode}\n")
        if self.new_server_path is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-server-path=server/partial/or/full/path")
        if mode == 'rrsync' and self.new_pub_key is None:
            sys.exit("babackup_server: attempt to add new rrsync backup without specifying --new-pub-key='ssh-foo BASE64 keyname'")

        # and some sanity checking, since we're going to put stuff in authorized_keys
        if re.search(r"\s", self.new_server_path):
            sys.exit("babackup_server: rejecting --new-server-path that contains whitespace")
        if self.new_pub_key is not None and self.new_pub_key.find("\n") >= 0:
            sys.exit("babackup_server: rejecting --new-pub-key that contains newline")

        name = self.name
        backup = {}
        backup["name"] = self.name
        backup["server_path"] = self.new_server_path
        backup["mode"] = mode
        if mode == 'rrsync':
            backup["pub_key"] = self.new_pub_key

        #
        # create the target directory
        #
        full_new_server_path = self.new_server_path
        if full_new_server_path[0] == '~':
            full_new_server_path = os.path.expanduser(full_new_server_path)
        if not os.path.isdir(full_new_server_path):
            self.verbose_log(f"babackup_server: mkdir {full_new_server_path}", 1)
            if not self.debug:
                os.mkdir(full_new_server_path)
                os.mkdir(full_new_server_path + "/current")
                os.mkdir(full_new_server_path + "/current/last")
                os.mkdir(full_new_server_path + "/current/last/data")

        #
        # change authorized_keys
        #
        if mode == 'rrsync':
            auth = 'command="' + self.RRSYNC_PATH + ' -wo ' + self.new_server_path + '/current",no-agent-forwarding,no-port-forwarding,no-pty,no-user-rc,no-X11-forwarding ' + self.new_pub_key
            self.verbose_log(f"babackup_server: adding public key to ~/.ssh/authorized_keys, with rrsync\n\t{auth}", 1)
            if not self.debug:
                auth_path = os.path.expanduser("~") + "/.ssh/authorized_keys"
                with open(auth_path, "a", encoding='utf-8') as auth_stream:
                    auth_stream.write(auth + "\n")

        self.check_crontab("babackup_server", "server", "1,16,31,46 * * * * /usr/sbin/babackup_server")
        self.check_rrsync()

        if self.conf['backups'] is None:
            self.conf['backups'] = []
        self.conf['backups'].append(backup)
        if self.debug:
            return
        self.write_conf()


        
    def check_backup(self, backup):
        """check one backup with configuration BACKUP

        If the client finished a backup (begin and end markers present,
        end newer than begin), commit it: roll current/ aside as new/last
        and swap an empty new/ in as the fresh current/.  Then, if the
        previous generation (current/last/last) looks complete, move it
        into archive/<isotime> and flag this backup for an archive aging
        pass (backup['archive_check']).  In --debug mode only logs.
        """

        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        server_path = backup.get("server_path")
        if server_path is None:
            sys.exit(f"babackup_server: backup {name} has no path")
        # allow ~ or ~user paths in the config
        if server_path[0] == "~":
            server_path = os.path.expanduser(server_path)

        #
        # see if a this one finished a backup since last check
        #
        # 0. no begin: never started
        # 1. begin no end => in progress (or failed)
        # 2. begin and end, but end before begin => missed it and it's running again
        # 3. begin and end, but end after begin => good!
        #

        # no backup
        if not os.path.exists(f"{server_path}/current/begin"):
            self.verbose_log(f"babackup_server: {name} backup idle", 1)
            return

        # case 1: started but not finished
        if not os.path.exists(f"{server_path}/current/end"):
            self.verbose_log(f"babackup_server: {name} backup is active", 1)
            return

        # cases 2 and 3: compare marker mtimes (mtime, not contents,
        # because the contents came from the untrusted client)
        begin_mtime = os.path.getmtime(f"{server_path}/current/begin")
        end_mtime = os.path.getmtime(f"{server_path}/current/end")
        if begin_mtime > end_mtime:
            self.verbose_log(f"babackup_server: {name} backup was complete but is active again", 2)
            return

        #
        # commit!
        #
        # (Note that there is a test-to-use race going on here :-( )
        #
        self.verbose_log(f"babackup_server: {name} is rolling current to last", 1)
        if not self.debug:
            # roll: current -> new/last, then the empty new -> current
            if not os.path.exists(f"{server_path}/new"):
                os.mkdir(f"{server_path}/new")
            if os.path.exists(f"{server_path}/current"):
                os.rename(f"{server_path}/current", f"{server_path}/new/last")
            os.rename(f"{server_path}/new", f"{server_path}/current")

        #
        # move the old last into archive
        #
        # Note that we trust the file mtime rather than contents,
        # since the contents came from the user.
        #
        # Only do this if that backup looks good (has begin and end files).
        # Otherwise we let that last linger, eventually to be garbage collected
        # when current/last goes away
        #
        if os.path.isdir(f"{server_path}/current/last/last") and os.path.exists(f"{server_path}/current/last/last/begin") and os.path.exists(f"{server_path}/current/last/last/end"):
            # name the archive entry after when that backup began (UTC)
            last_begin_mtime = os.path.getmtime(f"{server_path}/current/last/last/begin")
            last_isotime = datetime.datetime.fromtimestamp(last_begin_mtime, datetime.timezone.utc).isoformat(timespec = 'minutes')
            self.verbose_log(f"babackup_server: {name} is moving last/last to archive/{last_isotime}", 2)
            # get rid of : in time, to be more filename friendly
            # (check_backup_archive undoes this when parsing the name)
            last_isotime = last_isotime.replace(":", "_")
            if not self.debug:
                if not os.path.isdir(f"{server_path}/archive"):
                    os.mkdir(f"{server_path}/archive")
                if os.path.isdir(f"{server_path}/archive/{last_isotime}"):
                    # xxx: we have two things with the same time, so the rename will fail.
                    pass
                else:
                    os.rename(f"{server_path}/current/last/last", f"{server_path}/archive/{last_isotime}")
            # we want to come back to this archive and check it later
            backup['archive_check'] = True

    def rmtree(self, path):
        """like shutil.rmtree, but wrapped to handle errors"""

        def redo_with_write(func, path, excinfo):
            st = os.stat(path)
            if (st.st_mode & stat.S_IWRITE) == 0:
                os.chmod(path, st.st_mode | stat.S_IWRITE)
            # and the directory
            parent = os.path.dirname(path)
            parent_st = os.stat(parent)
            if (parent_st.st_mode & stat.S_IWRITE) == 0:
                os.chmod(parent, parent_st.st_mode | stat.S_IWRITE)
            func(path)
        
        (major, minor, _) = list(map(int, platform.python_version_tuple()))
        if (major == 3 and minor >= 12) or major > 3:
            shutil.rmtree(path, onexc = redo_with_write)
        else:
            shutil.rmtree(path, onerror = redo_with_write)
        
            
    def check_backup_archive(self, backup):
        """check the archive of one backup (configuration BACKUP) for outdated entries"""

        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        server_path = backup.get("server_path")
        if server_path[0] == "~":
            server_path = os.path.expanduser(server_path)
        if server_path is None:
            sys.exit(f"babackup_server: backup {name} has no path")
        if not os.path.isdir(f"{server_path}/archive"):
            self.verbose_log(f"babackup_server: {name} check has no archive ({server_path}/archive)", 2)
            return

        self.verbose_log(f"babackup_server: {name} check archive", 2)

        #
        # walk the archives
        # finding things to delete (ending ~)
        # and to date check.
        #
        path_to_remove = []
        to_date_check = []
        part_to_timestamp = {}
        timestamp_to_part = {}

        iso_matcher = re.compile(r"^(\d{4})-?(\d{2})-?(\d{2})([tT]\d{2}[-:_]?\d{2})?([-+]\d{2}[-:_]\d{2})?$")
        archives_listdir = os.listdir(f"{server_path}/archive")
        for part in sorted(archives_listdir, reverse=True):
            if part.endswith("~"):
                self.verbose_log(f"babackup_server: {name} check finds corpse {part}", 2)
                path_to_remove.append(f"{server_path}/archive/{part}")
            elif iso_matcher.match(part):
                to_date_check.append(part)
                # undo our prior : removal
                clean_part = part.replace("_", ":")
                part_to_timestamp[part] = datetime_datetime_fromisoformat_compat(clean_part).timestamp()
                timestamp_to_part[part_to_timestamp[part]] = part
            else:
                if self.verbose > 1:
                    self.verbose_log(f"babackup_server: {name} check finds suprising file {part} in {server_path}/archive", 3)

        #
        # apply the priorization algorithm
        #
        to_keep = {}
        keep_periods = ['daily', 'weekly', 'monthly', 'triannual'] 
        keep_period_interval = [23*60*60, 7*86400-3600, 28*86400-3600*4, 28*86400*4-3600*8]
        if self.to_keep is None:
            self.to_keep = {}
        for period in keep_periods:
            proposal = self.to_keep.get(period, self.DEFAULT_TO_KEEP)
            if proposal < 0:
                proposal = self.DEFAULT_TO_KEEP
            to_keep[period] = proposal

        newest_to_oldest_timestamps = sorted(part_to_timestamp.values(), reverse=True)
        count_to_keep = 0
        timestamps_to_keep = []
        timestamps_to_retire = []

        while True:
            if not newest_to_oldest_timestamps:
                break

            if count_to_keep <= 0:
                if len(keep_periods) <= 0:
                    keeping_what = 'no_more'
                else:
                    keeping_what = keep_periods[0]
                    keep_periods = keep_periods[1:]
                    count_to_keep = to_keep[keeping_what]
                    distance_too_close = keep_period_interval[0]
                    keep_period_interval = keep_period_interval[1:]
                    # have to keep the next one
                    last_timestamp_kept = None

            #
            # keep the first
            #
            timestamp = newest_to_oldest_timestamps.pop(0)
            self.verbose_log(f"babackup_server: {name} keep {keeping_what} " + timestamp_to_part[timestamp], 2)
            timestamps_to_keep.append(timestamp)
            count_to_keep -= 1

            #
            # suppress
            #
            suppress_until_timestamp = timestamp - distance_too_close
            suppress_queue = []
            while True:
                if not newest_to_oldest_timestamps:
                    break
                timestamp = newest_to_oldest_timestamps.pop(0)
                if timestamp < suppress_until_timestamp:
                    # too far, put it back
                    newest_to_oldest_timestamps.insert(0, timestamp)
                    break
                suppress_queue.append(timestamp)

            # keep the last one
            if suppress_queue:
                # put back the last one
                timestamp = suppress_queue.pop()
                newest_to_oldest_timestamps.insert(0, timestamp)
                for timestamp in suppress_queue:
                    self.verbose_log(f"babackup_server: {name} aging {keeping_what} " + timestamp_to_part[timestamp], 1)
                    timestamps_to_retire.append(timestamp)

        #
        # take action
        #
        # first rename them (easy)
        for timestamp in timestamps_to_retire:
            part = timestamp_to_part[timestamp]
            self.verbose_log(f"babackup_server: {name} will retire {server_path}/archive/{part}", 1)
            if not self.debug:
                os.rename(f"{server_path}/archive/{part}", f"{server_path}/archive/{part}~")
            path_to_remove.append(f"{server_path}/archive/{part}~")
        # then remove them, oldest first
        for old_path in sorted(path_to_remove):
            self.verbose_log(f"babackup_server: {name} is removing tree {old_path}", 2)
            if not self.debug:
                self.rmtree(old_path)
            
        
    def check_backups(self):
        """check all backups (or whatever was specified with -N)"""
        self.read_conf()
        if self.check_current:
            for backup in self.conf['backups']:
                if (self.name is None or self.name == backup['name']) and backup.get("enabled", True):
                    self.check_backup(backup)
        #
        # Now that we've checked each, go back and look at aging.
        #
        force_check_archive = False
        if self.check_archive is not None:
            force_check_archive = self.check_archive
        for backup in self.conf['backups']:
            check_this_archive = False
            if backup.get("archive_check"):
                check_this_archive = True
            elif force_check_archive:
                check_this_archive = True
                if self.name is not None and self.name != backup.get('name', ''):
                    check_this_archive = False
            if check_this_archive:
                self.check_backup_archive(backup)


                
if __name__ == '__main__':
    # constructing Program runs the whole tool (see Program.__init__);
    # any failure path inside calls sys.exit with a message itself
    Program()
    sys.exit(0)

