#!/usr/bin/env python3

# Copyright (C) 2020  Braiins Systems s.r.o.
#
# This file is part of Braiins Open-Source Initiative (BOSI).
#
# BOSI is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Please, keep in mind that we may also license BOSI or any part thereof
# under a proprietary license. For more information on the terms and conditions
# of such proprietary license or if you have any other questions, please
# contact us at opensource@braiins.com.

import argparse
import base58
import logging
import os
import subprocess
import sys
import tarfile
import tempfile

from urllib.request import Request, urlopen

import bos_toolbox.platform.backup as backup
import bos_utils.hwid as hwid

from bos_toolbox.batch import read_hosts
from bos_toolbox.cache import WebCacheContext, WebCache
from bos_toolbox.common import (
    BOS_URl,
    BOS_PLUS_URl,
    BOS_PLUS_NIGHTLY_URl,
    BOS_REFERRAL_URL,
    STAGE3_BUILTIN_SRC_DIR,
)
from bos_toolbox.platform import get_platform_from_ssh, OsType, Platform, PlatformStop
from bos_toolbox.transfer import upload_local_files, wait_for_port, Progress
from bos_toolbox.util import get_payload_path, get_refid
from bos_utils.packages import Packages
from bos_utils.ssh import SSHManager, SSHError

from .unlock import unlock, UnlockStop, BATCH_PASSWORD_RETRIES
from .update import update, UpdateFail

LOG = logging.getLogger(__name__)

# default credentials for connecting to device. will ask password if they wont do
USERNAME = 'root'
PASSWORD = 'admin'

# binaries necessary for upgrade will be copied onto host from here (sftp, fw_printenv)
SYSTEM_DIR = 'system'
# recovery partition files?
SOURCE_DIR = 'firmware'
# path to auxiliary scripts
PAYLOAD_DIR = get_payload_path()
# stage3 stagging dir, to be tarballed for transfer
TARGET_DIR = '/tmp/firmware'

STAGE3_DIR = 'upgrade'
STAGE3_BUILTIN_DIR = 'upgrade/builtin'
STAGE3_USER_DIR = 'upgrade/usr'
STAGE3_FILE = 'stage3.tgz'
STAGE3_SCRIPT = 'stage3.sh'
STAGE3_REFID = 'bos_refid'
STAGE3_REFERRAL = 'referral'
STAGE3_BOS_MGMT_ID = 'bos_mgmt_id'

REBOOT_DELAY = (3, 5)


class UpgradeStop(Exception):
    pass


class FirmwareInfo:
    PLUS_SIGNATURE = '-plus'

    def __init__(self, web_cache: WebCache, url, version: str, name: str, ext: str):
        self._web_cache = web_cache
        self.url = url
        self.version = version
        self.name = name
        self.ext = ext

    @property
    def cache_name(self):
        return self.name + self.ext

    def is_plus(self):
        return self.PLUS_SIGNATURE in self.name

    def extract_to_tmp_dir(self):
        extracted_dir = self._web_cache.download_and_extract(self.cache_name, self.url)
        return os.path.join(extracted_dir, self.name)


class PackageInfo:
    def __init__(
        self, web_cache: WebCache, root_url, version: str, name: str, filename
    ):
        self._web_cache = web_cache
        self.root_url = root_url
        self.version = version
        self.name = name
        self.filename = filename

    @property
    def url(self):
        return '/'.join([self.root_url, self.filename])

    def download(self):
        return self._web_cache.download(self.filename, self.url)


class FirmwareCache:
    FEEDS_PREFIX = 'toolbox_bos_'
    REFERRAL_INDEX = 'Packages'
    FILE_EXT = '.tar.gz'

    def __init__(
        self, web_cache: WebCache, feeds_url, referral_url, fw_version: str, refid: str
    ):
        self._web_cache = web_cache
        self.feeds_url = feeds_url.rstrip('/')
        self.referral_url = referral_url.rstrip('/')
        self.fw_version = fw_version
        self.refid = refid
        self.target_list = {}
        self.referral_list = []

    def get_feeds_url(self, target: str):
        return '/'.join([self.feeds_url, self.FEEDS_PREFIX + target])

    def _load_fw_list(self, target: str):
        feeds_url = self.get_feeds_url(target)
        fw_list = []
        with urlopen(
            Request(feeds_url, headers={'User-Agent': 'Mozilla/5.0'})
        ) as response:
            LOG.info("Downloading firmware list '{}'...".format(feeds_url))
            for line in response.readlines():
                fw_version, fw_file = line.decode().split()
                if not fw_file.endswith(self.FILE_EXT):
                    LOG.error(
                        "Unexpected file extension of firmware '{}'!".format(fw_file)
                    )
                    raise UpgradeStop
                fw_name = fw_file[: -len(self.FILE_EXT)]
                fw_url = '/'.join([self.feeds_url, fw_version, fw_file])
                fw_list.append(
                    FirmwareInfo(
                        self._web_cache, fw_url, fw_version, fw_name, self.FILE_EXT
                    )
                )
        return fw_list

    def get_fw_list(self, target: str):
        fw_list = self.target_list.get(target)
        if not fw_list:
            fw_list = self._load_fw_list(target)
            self.target_list[target] = fw_list
        return fw_list

    def get_firmware_info(self, target: str) -> FirmwareInfo:
        fw_version = self.fw_version
        fw_list = self.get_fw_list(target)
        fw_info = next(
            (x for x in fw_list if not fw_version or x.version == fw_version), None
        )
        if not fw_info:
            msg = "firmware '{}'".format(fw_version) if fw_version else 'any firmware'
            LOG.error("Cannot get {} for platform '{}'!".format(msg, target))
            raise UpgradeStop
        return fw_info

    def get_referral_url(self, filename=None):
        url_list = [self.referral_url, self.refid]
        if filename:
            url_list.append(filename)
        return '/'.join(url_list)

    def _load_referral_list(self):
        referral_index_url = self.get_referral_url(self.REFERRAL_INDEX)
        referral_list = {}
        with urlopen(
            Request(referral_index_url, headers={'User-Agent': 'Mozilla/5.0'})
        ) as response:
            LOG.info("Preparing referral program with id '{}'...".format(self.refid))
            lines = (line.decode() for line in response)
            referral_url = self.get_referral_url()
            for package in Packages(None, lines):
                prev_package = referral_list.get(package.name)
                if not prev_package or prev_package.version < package.version:
                    referral_list[package.name] = PackageInfo(
                        self._web_cache,
                        referral_url,
                        package.version,
                        package.name,
                        package.filename,
                    )
        return list(referral_list.values())

    def init_referral(self):
        if not self.refid:
            return

        self.referral_list = self._load_referral_list()


def cleanup_system(ssh, platform):
    LOG.info('Cleaning remote system...')
    platform.cleanup_system(ssh)


def check_stage3_path(path):
    if not os.path.isdir(path):
        LOG.error(
            "Post-upgrade path '{}' is missing or is not a directory!".format(path)
        )
        raise UpgradeStop
    if not os.path.isfile(os.path.join(path, STAGE3_SCRIPT)):
        LOG.error("Script '{}' is missing in '{}'!".format(STAGE3_SCRIPT, path))
        raise UpgradeStop


def check_bos_mgmt_id(id):
    valid = False
    try:
        valid = len(base58.b58decode_check(id.encode())) == 12
    except ValueError:
        pass
    if not valid:
        LOG.error("Invalid BOS management id '{}'!".format(id))
        raise UpgradeStop


def try_unlock(host: str, web_password, batch: bool) -> bool:
    try:
        params = {}
        if web_password:
            params['password'] = web_password
        if batch:
            params['password_retries'] = BATCH_PASSWORD_RETRIES
        unlock(host, **params)
    except UnlockStop:
        return False
    else:
        return True


def main(parser, args):
    stage3_user_path = args.post_upgrade
    stage3_builtin_path = None

    if stage3_user_path:
        check_stage3_path(stage3_user_path)
    if os.path.isdir(STAGE3_BUILTIN_SRC_DIR):
        check_stage3_path(STAGE3_BUILTIN_SRC_DIR)
        stage3_builtin_path = STAGE3_BUILTIN_SRC_DIR

    if args.bos_mgmt_id:
        check_bos_mgmt_id(args.bos_mgmt_id)

    if args.feeds_url:
        feeds_url = args.feeds_url
    elif args.open_source:
        feeds_url = BOS_URl
    elif args.nightly:
        feeds_url = BOS_PLUS_NIGHTLY_URl
    else:
        feeds_url = BOS_PLUS_URl

    referral_url = args.referral_url or BOS_REFERRAL_URL

    # user is not handled at all since we need root
    # ssh wrapper may ask for password based on it's own logic, we just provide default
    password = args.password or PASSWORD

    hosts = read_hosts(args.hosts)
    batch = len(hosts) > 1

    error = None
    with WebCacheContext() as web_cache:
        fw_cache = FirmwareCache(
            web_cache, feeds_url, referral_url, args.fw_version, get_refid()
        )
        fw_cache.init_referral()
        for host in hosts:
            curr_error = None
            tried_unlock = False
            while True:
                try:
                    install(
                        args,
                        fw_cache,
                        host,
                        batch,
                        USERNAME,
                        password,
                        stage3_user_path,
                        stage3_builtin_path,
                    )
                except SSHError as ex:
                    if tried_unlock:
                        curr_error = UpgradeStop
                        LOG.error(str(ex))
                    elif try_unlock(host, password, batch):
                        # miner has been successfully unlocked and set default Antminer password
                        # try again ssh connection with default Antminer password...
                        password = PASSWORD
                        tried_unlock = True
                        continue
                except (UpgradeStop, UpdateFail, PlatformStop) as ex:
                    curr_error = ex
                break
            # do not stop batch mode when one host fails
            if curr_error and batch:
                LOG.error('Skipping host {}!'.format(host))
            if curr_error and not error:
                # store first error
                error = curr_error
    if error:
        raise error


def install(
    args,
    fw_cache: FirmwareCache,
    host,
    batch,
    username,
    password,
    stage3_user_path,
    stage3_builtin_path,
):
    bos_mgmt_id = args.bos_mgmt_id

    with SSHManager(host, username, password, load_host_keys=False) as ssh:
        platform = get_platform_from_ssh(ssh)
        if not platform:
            LOG.error('Braiins OS is being installed on unsupported platform!')
            raise UpgradeStop

        # check compatibility of remote system
        if platform.os_info.type == OsType.BOS:
            # the system already runs BOS so call update instead
            LOG.info(
                "Remote target is already running Braiins OS '{}'".format(
                    platform.os_info.version
                )
            )
            LOG.info('Running standard system update...')
            update(ssh)
            LOG.info('Update finished successfully!')
            return

        firmware = fw_cache.get_firmware_info(platform.triple.target_full_name)
        firmware_plus = '+' if firmware.is_plus() else ''

        LOG.info(
            'Preparing Braiins OS{} {} ({})...'.format(
                firmware_plus, firmware.version, firmware.name
            )
        )
        firmware_dir = firmware.extract_to_tmp_dir()

        LOG.info('Installing system to {}...'.format(host))

        if args.backup:
            mac = platform.probe.get_mac()
            backup_dir = backup.get_output_dir(mac)
            if not platform.backup_firmware(args, ssh, backup_dir, mac):
                raise UpgradeStop

        # prepare target directory
        ssh.run('rm', '-fr', TARGET_DIR)
        ssh.run('mkdir', '-p', TARGET_DIR)

        # upgrade remote system with missing utilities
        LOG.info('Preparing remote system...')
        platform.prepare_system(ssh, os.path.join(firmware_dir, SYSTEM_DIR))

        # copy firmware files to the server over SFTP
        sftp = ssh.open_sftp()
        sftp.chdir(TARGET_DIR)
        LOG.info('Uploading firmware...')
        upload_local_files(sftp, os.path.join(firmware_dir, SOURCE_DIR), '')
        if (
            stage3_user_path
            or stage3_builtin_path
            or fw_cache.refid
            or fw_cache.referral_list
            or bos_mgmt_id
        ):
            LOG.info('Uploading post-upgrade (stage3)...')
            with tempfile.TemporaryDirectory() as stage3_dir:
                stage3_file = os.path.join(stage3_dir, STAGE3_FILE)
                with tarfile.open(stage3_file, 'w:gz') as stage3:
                    if stage3_user_path:
                        stage3.add(stage3_user_path, STAGE3_USER_DIR)
                    if stage3_builtin_path:
                        stage3.add(stage3_builtin_path, STAGE3_BUILTIN_DIR)
                    stage3.add(
                        os.path.join(PAYLOAD_DIR, STAGE3_SCRIPT),
                        arcname='/'.join([STAGE3_DIR, STAGE3_SCRIPT]),
                    )
                    if fw_cache.refid:
                        bos_refid_path = os.path.join(stage3_dir, STAGE3_REFID)
                        with open(bos_refid_path, 'w') as file:
                            file.write('{}\n'.format(fw_cache.refid))
                        stage3.add(
                            bos_refid_path, arcname='/'.join([STAGE3_DIR, STAGE3_REFID])
                        )
                    for referral in fw_cache.referral_list:
                        referral_ipk_path = referral.download()
                        stage3.add(
                            referral_ipk_path,
                            arcname='/'.join(
                                [STAGE3_DIR, STAGE3_REFERRAL, referral.filename]
                            ),
                        )
                    if bos_mgmt_id:
                        bos_mgmt_id_path = os.path.join(stage3_dir, STAGE3_BOS_MGMT_ID)
                        with open(bos_mgmt_id_path, 'w') as file:
                            file.write('{}\n'.format(bos_mgmt_id))
                        stage3.add(
                            bos_mgmt_id_path,
                            arcname='/'.join([STAGE3_DIR, STAGE3_BOS_MGMT_ID]),
                        )
                    stage3.close()
                    with Progress(
                        STAGE3_FILE, os.path.getsize(stage3_file)
                    ) as progress:
                        sftp.put(stage3_file, STAGE3_FILE, callback=progress)
        sftp.close()

        # generate HW identifier for miner
        hw_id = hwid.generate()

        # get other stage1 parameters
        if args.psu_power_limit == 0:
            # 0 is special parameter for disabling autotuning
            psu_power_limit = ''
        else:
            psu_power_limit = args.psu_power_limit or 'default'

        if args.keep_hostname:
            keep_hostname = 'yes'
        elif args.no_keep_hostname:
            keep_hostname = 'no'
        else:
            # keep only user defined hostname and skip factory one (default behaviour)
            keep_hostname = 'cond'

        pool_user = args.pool_user or ''
        keep_network = 'no' if args.no_keep_network else 'yes'
        keep_pools = 'no' if args.no_keep_pools else 'yes'
        auto_upgrade = 'no' if args.no_auto_upgrade else 'yes'
        dry_run = 'yes' if args.dry_run else 'no'

        # run stage1 upgrade process
        try:
            LOG.info('Upgrading firmware...')
            stdout, _ = ssh.run(
                'cd',
                TARGET_DIR,
                '&&',
                'ls',
                '-l',
                '&&',
                "/bin/sh stage1.sh '{}' '{}' '{}' '{}' '{}' '{}' '{}' '{}'".format(
                    hw_id,
                    pool_user,
                    psu_power_limit,
                    keep_network,
                    keep_hostname,
                    keep_pools,
                    auto_upgrade,
                    dry_run,
                ),
            )
        except subprocess.CalledProcessError as error:
            cleanup_system(ssh, platform)
            LOG.error('Error log:')
            for line in error.stderr.readlines():
                print(line, end='')
            raise UpgradeStop
        else:
            if args.dry_run:
                cleanup_system(ssh, platform)
                LOG.info('Dry run of upgrade was successful!')
                return

            for line in stdout.readlines():
                print(line, end='')
            LOG.info('Upgrade was successful!')
            print('Rebooting...', end='')
            try:
                ssh.run('/sbin/reboot')
            except subprocess.CalledProcessError:
                # reboot returns exit status -1
                pass

    # do not wait when install is called in batch mode
    if not batch:
        if args.no_wait:
            print()
            LOG.info(
                'Wait for 120 seconds before the system becomes fully operational!'
            )
        else:
            wait_for_port(host, 80, REBOOT_DELAY)


def build_arg_parser(parser):
    parser.description = 'Install Braiins OS[+] onto a mining machine'
    parser.add_argument(
        'hosts',
        nargs='?',
        help='hostname or path to file with hosts of miners with original firmware',
    )
    parser_feeds = parser.add_mutually_exclusive_group()
    parser_feeds.add_argument(
        '--open-source',
        action='store_true',
        help='use for installation open source version',
    )
    parser_feeds.add_argument(
        '--nightly', action='store_true', help='use for installation nightly version'
    )
    parser_feeds.add_argument(
        '--feeds-url', nargs='?', help='override default feeds server URL'
    )
    parser_feeds.add_argument('--referral-url', nargs='?', help=argparse.SUPPRESS)
    parser.add_argument(
        '--fw-version', nargs='?', help='select specific firmware version'
    )
    parser.add_argument(
        '--backup', action='store_true', help='do miner backup before upgrade'
    )
    parser.add_argument(
        '--no-auto-upgrade',
        action='store_true',
        help='turn off auto-upgrade of installed firmware',
    )
    parser.add_argument(
        '--no-nand-backup',
        action='store_true',
        help='skip full NAND backup (config is still being backed up)',
    )
    parser.add_argument(
        '--pool-user', nargs='?', help='set username and workername for default pool'
    )
    parser.add_argument(
        '--psu-power-limit', nargs='?', type=int, help='set PSU power limit (in watts)'
    )
    parser.add_argument(
        '--no-keep-network',
        action='store_true',
        help='do not keep miner network configuration (use DHCP)',
    )
    parser.add_argument(
        '--no-keep-pools',
        action='store_true',
        help='do not keep miner pool configuration',
    )
    parser.add_argument(
        '--no-keep-hostname',
        action='store_true',
        help='do not keep miner hostname and generate new one based on MAC',
    )
    parser.add_argument(
        '--keep-hostname', action='store_true', help='force to keep any miner hostname'
    )
    parser.add_argument(
        '--no-wait',
        action='store_true',
        help='do not wait until system is fully upgraded',
    )
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help='do all upgrade steps without actual upgrade',
    )
    parser.add_argument(
        '--post-upgrade', nargs='?', help='path to directory with stage3.sh script'
    )
    parser.add_argument(
        '--bos-mgmt-id', nargs='?', help='set BOS management identifier'
    )
    parser.add_argument('-p', '--password', help='administration password')


if __name__ == '__main__':
    # execute only if run as a script
    parser = argparse.ArgumentParser()
    build_arg_parser(parser)
    # parse command line arguments
    args = parser.parse_args(sys.argv[1:])

    try:
        main(parser, args)
    except KeyboardInterrupt:
        print()
        sys.exit(1)
    except SSHError as e:
        LOG.error(str(e))
        sys.exit(1)
    except UpgradeStop:
        sys.exit(2)
    except PlatformStop:
        sys.exit(3)
