#!/usr/bin/env nix-shell
#! nix-shell -i python3 --pure
#! nix-shell -p python3 python3Packages.coloredlogs python3Packages.progressbar2

# Handle Syncthing's .sync-conflict files: find them, compare the versions, keep one

import argparse
import logging
import os
import pickle
import re
import sys
import zlib

import coloredlogs
import progressbar

# Route stderr through progressbar so that log lines emitted while a bar is
# displayed do not garble it. NOTE(review): believed to need to run before the
# log handler below is installed (progressbar2 logging guidance) — keep order.
progressbar.streams.wrap_stderr()
# Colored log output, INFO level and above, formatted as "LEVEL message".
coloredlogs.install(level="INFO", fmt="%(levelname)s %(message)s")
# Root logger, used by every class and the script entry point in this file.
log = logging.getLogger()

# 1) Build the list of conflict files
# 2) Gather file information (date, owner, size, checksum)
# 3) Propose what to do


def sizeof_fmt(num, suffix="B"):
    """Render a byte count with binary (IEC) prefixes, e.g. ``1536 -> "1.5 KiB"``.

    The value is divided by 1024 until its magnitude drops below 1024 and is
    then printed with one decimal. Anything that is still too large after
    "Zi" is reported in "Yi" units.
    """
    # Adapted from https://stackoverflow.com/a/1094933
    value = num
    for prefix in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(value) < 1024.0:
            return f"{value:3.1f} {prefix}{suffix}"
        value /= 1024.0
    return f"{value:.1f} Yi{suffix}"


class Table:
    """Minimal fixed-size text table, stored column-major as ``data[x][y]``.

    Every cell holds a string. ``print`` left-aligns each column to its widest
    cell, prefixes every column but the first with ``" | "``, and ends each
    cell with a tab.
    """

    def __init__(self, width, height):
        self.width = width
        self.height = height
        # One list per column, each holding `height` empty cells.
        self.data = [[""] * height for _ in range(width)]

    def set(self, x, y, data):
        """Store ``str(data)`` in column ``x``, row ``y``."""
        self.data[x][y] = str(data)

    def print(self):
        """Write the table to stdout, one line per row."""
        colWidths = [max(len(text) for text in column) for column in self.data]
        for row in range(self.height):
            for col in range(self.width):
                text = self.data[col][row]
                separator = " | " if col else ""
                padding = " " * (colWidths[col] - len(text))
                print(separator + text + padding, end="\t")
            print()


class Database:
    """Persistent index of the sync-conflict files found under one directory.

    ``data`` maps ``(root, filename)`` -- a directory and a file name with the
    conflict marker stripped -- to the ``DatabaseFile`` grouping every version
    of that file. The script pickles the whole object between runs (see the
    ``--database`` option) so that stats and checksums can be reused.
    """

    VERSION = 1
    # Syncthing conflict marker, e.g. ".sync-conflict-20200101-120000-ABCDEFG".
    # Raw string: "\." "\d" "\w" are invalid escape sequences in a plain literal.
    CONFLICT_PATTERN = re.compile(r"\.sync-conflict-\d{8}-\d{6}-\w{7}")

    def __init__(self, directory):
        """Create an empty database for ``directory``."""
        self.version = Database.VERSION
        self.directory = directory
        self.data = dict()

    def prune(self):
        """Drop vanished versions, then entries that no longer need resolving."""
        toPrune = list()
        for filepath, databaseFile in self.data.items():
            databaseFile.migrate()  # TODO Temp dev stuff
            databaseFile.prune()
            if not databaseFile.isRelevant():
                toPrune.append(filepath)
        for filepath in toPrune:
            del self.data[filepath]

    def nbFiles(self):
        """Total number of versions (original + conflict copies) tracked."""
        return sum(databaseFile.nbFiles() for databaseFile in self.data.values())

    def totalSize(self):
        """Total size in bytes of every tracked version that has been stat'ed."""
        return sum(databaseFile.totalSize() for databaseFile in self.data.values())

    def maxSize(self):
        """Sum, over tracked files, of the size of their largest version.

        ``totalSize() - maxSize()`` is the smallest number of bytes freed by
        keeping exactly one version of every file.
        """
        return sum(databaseFile.maxSize() for databaseFile in self.data.values())

    def totalChecksumSize(self):
        """Bytes that still have to be read to compute the missing checksums."""
        return sum(
            databaseFile.totalChecksumSize() for databaseFile in self.data.values()
        )

    def getList(self):
        """Walk ``directory`` and register every conflict file and its original.

        Existing entries are pruned first, so entries loaded from a previous
        run keep their cached stats and checksums.
        """
        self.prune()

        log.info("Finding conflict files")
        widgets = [
            progressbar.AnimatedMarker(),
            " ",
            progressbar.BouncingBar(),
            " ",
            progressbar.DynamicMessage("conflicts"),
            " ",
            progressbar.DynamicMessage("files"),
            " ",
            progressbar.DynamicMessage("dir", width=20, precision=20),
            " ",
            progressbar.Timer(),
        ]
        bar = progressbar.ProgressBar(widgets=widgets).start()
        f = 0
        for root, dirs, files in os.walk(self.directory):
            for conflictFilename in files:
                f += 1
                if not Database.CONFLICT_PATTERN.search(conflictFilename):
                    continue
                filename = Database.CONFLICT_PATTERN.sub("", conflictFilename)
                key = (root, filename)
                if key in self.data:
                    dataFile = self.data[key]
                else:
                    dataFile = DatabaseFile(root, filename)
                    self.data[key] = dataFile

                # The original (un-suffixed) file takes part in the comparison too
                if filename in files:
                    dataFile.addConflict(filename)
                dataFile.addConflict(conflictFilename)

            bar.update(
                conflicts=len(self.data), files=f, dir=root[(len(self.directory) + 1) :]
            )
        bar.finish()
        log.info(
            f"Found {len(self.data)} conflicts, totalling {self.nbFiles()} conflict files."
        )

    def getStats(self):
        """Refresh ``os.stat`` information of every tracked version."""
        log.info("Getting stats from conflict files")
        bar = progressbar.ProgressBar(max_value=self.nbFiles()).start()
        f = 0
        for databaseFile in self.data.values():
            databaseFile.getStats()
            f += databaseFile.nbFiles()
            bar.update(f)
        bar.finish()
        log.info(
            f"Total file size: {sizeof_fmt(self.totalSize())}, possible save: {sizeof_fmt(self.totalSize() - self.maxSize())}"
        )

    def getChecksums(self):
        """Compute the missing checksums of every tracked file.

        A failure on one file is logged and the next file is processed.
        On KeyboardInterrupt the method returns early, so the caller can still
        save what has been computed so far.
        """
        log.info("Checksumming conflict files")
        widgets = [
            progressbar.DataSize(),
            " of ",
            progressbar.DataSize("max_value"),
            " (",
            progressbar.AdaptiveTransferSpeed(),
            ") ",
            progressbar.Bar(),
            " ",
            progressbar.DynamicMessage("dir", width=20, precision=20),
            " ",
            progressbar.DynamicMessage("file", width=20, precision=20),
            " ",
            progressbar.Timer(),
            " ",
            progressbar.AdaptiveETA(),
        ]
        bar = progressbar.DataTransferBar(
            max_value=self.totalChecksumSize(), widgets=widgets
        ).start()
        f = 0
        for databaseFile in self.data.values():
            bar.update(
                f,
                dir=databaseFile.root[(len(self.directory) + 1) :],
                file=databaseFile.filename,
            )
            # Measured before summing: summing replaces the pending (None) entries
            f += databaseFile.totalChecksumSize()
            try:
                databaseFile.getChecksums()
            except KeyboardInterrupt:
                log.warning("Checksumming interrupted, remaining files left unsummed")
                return
            except Exception:
                # Best effort: one unreadable file must not abort the whole run.
                # Exception (not BaseException) so SystemExit & co. still propagate.
                log.exception(
                    f"{databaseFile.root}/{databaseFile.filename}: checksum failed"
                )
        bar.finish()

    def printDifferences(self):
        """Print, for every tracked file, the features its versions disagree on."""
        for databaseFile in self.data.values():
            print()
            databaseFile.printInfos(diff=True)

    def takeAction(self, execute=False, *args, **kwargs):
        """Decide which version to keep for every file, then apply it.

        Extra arguments are forwarded to ``DatabaseFile.decideAction``. Nothing
        is touched on disk unless ``execute`` is true.
        """
        for databaseFile in self.data.values():
            databaseFile.decideAction(*args, **kwargs)
            databaseFile.takeAction(execute=execute)


class DatabaseFile:
    """Every on-disk version (original + sync-conflict copies) of one file.

    ``conflicts``, ``stats`` and ``checksums`` are parallel lists: index ``f``
    describes the file ``os.path.join(root, conflicts[f])``.

    ``checksums[f]`` is one of:

    * ``None``: still to be computed;
    * an ``int``: running Adler-32 value. It covers only a prefix of the file
      if summing stopped early because the file diverged from all the others;
    * ``True``: all versions share the same inode, so they are the same file;
    * ``False``: all versions have different sizes, so they all differ.
    """

    BLOCK_SIZE = 4096
    RELEVANT_STATS = ("st_mode", "st_uid", "st_gid", "st_size", "st_mtime")

    def __init__(self, root, filename):
        self.root = root
        self.filename = filename
        self.stats = []
        self.conflicts = []
        self.checksums = []
        self.action = None  # index into `conflicts` of the version to keep
        log.debug(f"{self.root}/{self.filename} - new")

    def addConflict(self, conflict):
        """Track a new version (by file name); ignored if already tracked."""
        if conflict in self.conflicts:
            return
        self.conflicts.append(conflict)
        self.stats.append(None)
        self.checksums.append(None)
        log.debug(f"{self.root}/{self.filename} - add: {conflict}")

    def migrate(self):
        # Temp dev stuff since I don't want to resum that whole 400 GiB dir
        if self.stats is None:
            self.stats = [None] * len(self.conflicts)
        try:
            if self.checksums is None:
                self.checksums = [None] * len(self.conflicts)
        except AttributeError:
            self.checksums = [None] * len(self.conflicts)

    def removeConflict(self, conflict):
        """Stop tracking a version, keeping the parallel lists aligned."""
        f = self.conflicts.index(conflict)
        del self.conflicts[f]
        del self.stats[f]
        del self.checksums[f]
        log.debug(f"{self.root}/{self.filename} - del: {conflict}")

    def getPath(self, conflict):
        """Full path of the version named ``conflict``."""
        return os.path.join(self.root, conflict)

    def getPaths(self):
        """Full paths of every tracked version."""
        return [self.getPath(conflict) for conflict in self.conflicts]

    def prune(self):
        """Forget the versions that are no longer regular files on disk."""
        toPrune = [c for c in self.conflicts if not os.path.isfile(self.getPath(c))]
        for conflict in toPrune:
            self.removeConflict(conflict)

    def isRelevant(self):
        """Return whether there is still something to resolve for this file.

        False when no version is left, or when only the original file is left.
        A lone conflict copy without its original is relevant: it gets renamed
        back to the original name.
        """
        if not self.conflicts:
            return False
        if len(self.conflicts) == 1 and self.conflicts[0] == self.filename:
            return False
        return True

    def nbFiles(self):
        """Number of tracked versions."""
        return len(self.conflicts)

    def totalSize(self):
        """Total size of the stat'ed versions (unknown sizes count as 0)."""
        return sum((stat.st_size if stat is not None else 0) for stat in self.stats)

    def maxSize(self):
        """Size of the largest stat'ed version, 0 when there is none."""
        return max(
            ((stat.st_size if stat is not None else 0) for stat in self.stats),
            default=0,
        )

    def totalChecksumSize(self):
        """Bytes still to be read: sizes of stat'ed versions lacking a checksum."""
        size = 0
        for f, checksum in enumerate(self.checksums):
            if checksum is None:
                stat = self.stats[f]
                if stat is not None:
                    size += stat.st_size
        return size

    def getStats(self):
        """Refresh ``stats`` and invalidate checksums that may be stale.

        A cached checksum is kept only if the file's size, device, inode and
        ctime are unchanged. If any version has to be re-summed, all of them
        are. Checksums are compared while being computed and may stop early,
        so they are only comparable when they come from the same pass. Finally,
        cheap shortcuts set ``False`` (all sizes differ) or ``True`` (all
        versions are the same inode) for the whole list.
        """
        for f, conflict in enumerate(self.conflicts):
            oldStat = self.stats[f]
            newStat = os.stat(self.getPath(conflict))
            oldChecksum = self.checksums[f]

            # If it's been already summed, and we have the same inode and same ctime, don't resum
            if (
                oldStat is None
                or not isinstance(oldChecksum, int)
                or oldStat.st_size != newStat.st_size
                or oldStat.st_dev != newStat.st_dev
                or oldStat.st_ino != newStat.st_ino
                or oldStat.st_ctime != newStat.st_ctime
            ):
                self.checksums[f] = None

            self.stats[f] = newStat

        # A fresh partial checksum must never be compared to an old full one
        if None in self.checksums:
            self.checksums = [None] * len(self.conflicts)

        # If all the files are of different size, set as different files
        sizes = [s.st_size for s in self.stats]
        if len(sizes) == len(set(sizes)):
            self.checksums = [False] * len(self.conflicts)

        # If all the files are the same inode, set as same files
        if (
            len({s.st_ino for s in self.stats}) == 1
            and len({s.st_dev for s in self.stats}) == 1
        ):
            self.checksums = [True] * len(self.conflicts)

    def getChecksums(self):
        """Compute Adler-32 checksums of the versions whose checksum is ``None``.

        Pending files are read block by block in lockstep. A file stops being
        read as soon as its running checksum matches no other entry.

        On any error (I/O error, KeyboardInterrupt, ...) the checksums handled
        by this call are reset to ``None``. A partial value is therefore never
        saved and later mistaken for a complete one. Every opened file is
        closed before the exception propagates.
        """
        # TODO It's not even required to have a sum, this thing is not collision resistant now
        # TODO We might use BTRFS feature to know if conflict files are deduplicated between them
        pending = [f for f, checksum in enumerate(self.checksums) if checksum is None]
        filedescs = dict()
        try:
            for f in pending:
                self.checksums[f] = 1  # Adler-32 initial value
                filedescs[f] = open(self.getPath(self.conflicts[f]), "rb")

            while filedescs:
                toClose = set()

                # Compute checksums for next block for all files
                for f, filedesc in filedescs.items():
                    data = filedesc.read(DatabaseFile.BLOCK_SIZE)
                    self.checksums[f] = zlib.adler32(data, self.checksums[f])
                    if len(data) < DatabaseFile.BLOCK_SIZE:
                        toClose.add(f)

                # Stop summing as soon as checksums diverge
                for f in filedescs:
                    if self.checksums.count(self.checksums[f]) < 2:
                        toClose.add(f)

                for f in toClose:
                    filedescs.pop(f).close()
        except BaseException:
            for f in pending:
                self.checksums[f] = None
            raise
        finally:
            for filedesc in filedescs.values():
                filedesc.close()

    def getFeatures(self):
        """Map each feature name to its list of per-version values."""
        features = dict()
        features["name"] = self.conflicts
        features["sum"] = self.checksums
        for statName in DatabaseFile.RELEVANT_STATS:
            # int() truncates float timestamps (st_mtime): Syncthing also rounds them
            features[statName] = [int(getattr(stat, statName)) for stat in self.stats]
        return features

    def getDiffFeatures(self):
        """Subset of ``getFeatures`` whose values are not identical across versions."""
        return {
            key: vals for key, vals in self.getFeatures().items() if len(set(vals)) > 1
        }

    @staticmethod
    def shortConflict(conflict):
        """Conflict marker without the ".sync-conflict-" prefix, or "-" if none."""
        match = Database.CONFLICT_PATTERN.search(conflict)
        if match:
            return match[0][15:]  # 15 == len(".sync-conflict-")
        else:
            return "-"

    def printInfos(self, diff=True):
        """Print the path, then a table with one row per version.

        With ``diff`` only the features that differ between versions are shown.
        """
        print(os.path.join(self.root, self.filename))
        if diff:
            features = self.getDiffFeatures()
        else:
            features = self.getFeatures()
        features["name"] = [DatabaseFile.shortConflict(c) for c in self.conflicts]
        table = Table(len(features), len(self.conflicts) + 1)
        for x, featureName in enumerate(features.keys()):
            table.set(x, 0, featureName)
        for x, featureName in enumerate(features.keys()):
            for y in range(len(self.conflicts)):
                table.set(x, y + 1, features[featureName][y])
        table.print()

    def decideAction(self, mostRecent=False):
        """Choose which version to keep and store its index in ``action``.

        ``action`` stays ``None`` (skip) when no rule applies:

        * a single version is kept as is;
        * versions that differ only by name are the same file. This rule is
          used only when every checksum is known, so files whose content was
          never (fully) compared are not deleted;
        * with ``mostRecent``, the version with the latest mtime is kept. Ties
          go to the first one.
        """
        # TODO More arguments for choosing
        reason = "undecided"
        self.action = None
        if len(self.conflicts) == 1:
            self.action = 0
            reason = "only file"
        else:
            features = self.getDiffFeatures()
            if len(features) == 1:
                # Only "name" differs
                if None in self.checksums:
                    reason = "checksums missing"
                else:
                    reason = "same files"
                    self.action = 0
            elif "st_mtime" in features and mostRecent:
                mtimes = features["st_mtime"]
                # max() returns the first maximal index, like a strict ">" scan
                self.action = max(range(len(mtimes)), key=mtimes.__getitem__)
                reason = "most recent"

        if self.action is None:
            log.warning(f"{self.root}/{self.filename}: skip, cause: {reason}")
        else:
            log.info(
                f"{self.root}/{self.filename}: keep {DatabaseFile.shortConflict(self.conflicts[self.action])}, cause: {reason}"
            )

    def takeAction(self, execute=False):
        """Keep version ``action`` under the original name and delete the others.

        Only logs (debug level) unless ``execute`` is true.
        """
        if self.action is None:
            return
        keep = self.conflicts[self.action]
        if keep != self.filename:
            log.debug(f"Rename {self.getPath(keep)} → {self.getPath(self.filename)}")
            if execute:
                # Atomically overwrites the original if it exists (also on Windows)
                os.replace(self.getPath(keep), self.getPath(self.filename))
        for conflict in self.conflicts:
            # After the rename the kept content lives at `filename`: never delete it
            if conflict == keep or conflict == self.filename:
                continue
            log.debug(f"Delete {self.getPath(conflict)}")
            if execute:
                os.unlink(self.getPath(conflict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Handle Syncthing's .sync-conflict files "
    )

    # Execution flow
    parser.add_argument(
        "directory", metavar="DIRECTORY", nargs="?", help="Directory to analyse"
    )
    parser.add_argument("-d", "--database", help="Database path for file informations")
    parser.add_argument(
        "-r",
        "--most-recent",
        action="store_true",
        help="Always keep the most recent version",
    )
    parser.add_argument(
        "-e", "--execute", action="store_true", help="Really apply changes"
    )
    parser.add_argument(
        "-p",
        "--print",
        action="store_true",
        help="Only print differences between files",
    )

    args = parser.parse_args()

    # Argument default values attribution
    if args.directory is None:
        args.directory = os.curdir
    args.directory = os.path.realpath(args.directory)

    # Create / load the database.
    # SECURITY: pickle.load can execute arbitrary code; only load database
    # files written by this script.
    database = None
    if args.database and os.path.isfile(args.database):
        try:
            with open(args.database, "rb") as databaseFile:
                database = pickle.load(databaseFile)
        except Exception as e:
            # Exception, not BaseException: let KeyboardInterrupt through
            raise ValueError("Not a database file") from e
        # Explicit checks rather than assert: asserts are stripped by "python -O"
        if not isinstance(database, Database):
            raise ValueError("Not a database file")
        if database.version > Database.VERSION:
            raise ValueError("Version of the loaded database is too recent")
        if database.directory != args.directory:
            raise ValueError("Directory of the loaded database doesn't match")

    if database is None:
        database = Database(args.directory)

    def saveDatabase():
        """Pickle the database to ``--database`` (no-op without that option).

        The data is written to a temporary file that then replaces the target,
        so an interrupted save never leaves a truncated database behind.
        """
        if not args.database:
            return
        tmpPath = args.database + ".tmp"
        with open(tmpPath, "wb") as databaseFile:
            pickle.dump(database, databaseFile)
        os.replace(tmpPath, args.database)

    # Save after every step so that an interrupted run can be resumed
    database.getList()
    saveDatabase()

    database.getStats()
    saveDatabase()

    database.getChecksums()
    saveDatabase()

    if args.print:
        database.printDifferences()
    else:
        database.takeAction(mostRecent=args.most_recent, execute=args.execute)