#!/usr/bin/env python3

# Copyright (C) 2019  Braiins Systems s.r.o.
#
# This file is part of Braiins Open-Source Initiative (BOSI).
#
# BOSI is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Please, keep in mind that we may also license BOSI or any part thereof
# under a proprietary license. For more information on the terms and conditions
# of such proprietary license or if you have any other questions, please
# contact us at opensource@braiins.com.

import argparse
import logging
import os
import sys

from getpass import getpass
from glob import glob
from urllib.request import Request, urlopen

import bos_toolbox.platform.backup as backup

from bos_toolbox.batch import read_hosts
from bos_toolbox.cache import WebCacheContext, WebCache
from bos_toolbox.common import BOS_PLUS_URl
from bos_toolbox.platform import (
    get_platform_from_ssh,
    OsType,
    OsMode,
    Platform,
    PlatformStop,
)
from bos_toolbox.transfer import wait_for_port
from bos_utils.ssh import SSHManager, SSHError

# Logger named after this module
LOG = logging.getLogger(__name__)

# SSH login name used for every connection made by this module
USERNAME = 'root'
# Fallback password when none is supplied; also used when reconnecting after
# the reboot to recovery mode
PASSWORD = None

# Passed to wait_for_port() while waiting for SSH after the reboot
# NOTE(review): the meaning of the two values is defined by wait_for_port -
# presumably a delay range in seconds; confirm
REBOOT_DELAY = (3, 8)


class RestoreStop(Exception):
    """Raised to abort the restore once the reason has been logged."""


class FirmwareInfo:
    """Name and URL of one stock firmware file plus the cache used to fetch it."""

    def __init__(self, web_cache: WebCache, url, name: str):
        """Remember the firmware *name*, its download *url* and the web cache."""
        self.name = name
        self.url = url
        self._web_cache = web_cache

    def extract_to_tmp_dir(self):
        """Download and extract the firmware through the web cache.

        Returns whatever ``WebCache.download_and_extract`` returns
        (presumably the path of the extraction directory - confirm).
        """
        cache = self._web_cache
        return cache.download_and_extract(self.name, self.url)


class RestoreCache:
    """Per-run cache of stock firmware feeds and helper for backup archives.

    Firmware lists are downloaded from ``<server_url>/toolbox_stock_<target>``
    and remembered per target, so each feed is fetched at most once.
    """

    FEEDS_PREFIX = 'toolbox_stock_'
    STOCK_DIR = 'stock'

    def __init__(self, web_cache: WebCache, server_url):
        """Bind the cache to *web_cache* and the feeds server *server_url*."""
        self._web_cache = web_cache
        # trailing slashes are dropped because URLs are assembled with '/'.join
        self.server_url = server_url.rstrip('/')
        # target name -> {variant name: FirmwareInfo}
        self.target_variants = {}

    def extract_backup(self, file_path):
        """Extract the backup archive *file_path* through the web cache.

        The path is passed for both arguments of ``WebCache.extract``
        (presumably cache key and source file - confirm).
        """
        return self._web_cache.extract(file_path, file_path)

    def remove_backup(self, file_path):
        """Drop the temporary cache entry created for *file_path*."""
        self._web_cache.remove_tmp_cache(file_path)

    def get_feeds_url(self, target: str):
        """Return the URL of the firmware list for *target*."""
        return '/'.join([self.server_url, self.FEEDS_PREFIX + target])

    def _load_fw_variants(self, target: str):
        """Download and parse the firmware list for *target*.

        Every non-blank line is expected to be ``<variant>\\t<file name>``.
        Returns a dict mapping variant name to ``FirmwareInfo``.

        Raises ``ValueError`` for a non-blank line that does not consist of
        exactly two tab-separated fields.
        """
        feeds_url = self.get_feeds_url(target)
        fw_variants = {}
        # log before the request so the URL is visible even if it fails
        LOG.info("Downloading firmware list '{}'...".format(feeds_url))
        request = Request(feeds_url, headers={'User-Agent': 'Mozilla/5.0'})
        with urlopen(request) as response:
            for raw_line in response:
                line = raw_line.decode()
                if not line.strip():
                    # tolerate blank lines (e.g. trailing newlines) in the feed
                    continue
                fw_variant, fw_file = line.split('\t')
                fw_file = fw_file.strip()
                fw_url = '/'.join([self.server_url, self.STOCK_DIR, fw_file])
                fw_variants[fw_variant] = FirmwareInfo(self._web_cache, fw_url, fw_file)
        return fw_variants

    def get_fw_variants(self, target: str):
        """Return the variant -> ``FirmwareInfo`` dict for *target*.

        The feed is downloaded on first use only; an empty feed is cached as
        well so it is not fetched again for every host.
        """
        fw_variants = self.target_variants.get(target)
        if fw_variants is None:
            fw_variants = self._load_fw_variants(target)
            self.target_variants[target] = fw_variants
        return fw_variants

    def get_firmware_info(self, target: str, variant: str) -> FirmwareInfo:
        """Return the firmware registered for *target* and *variant*.

        Logs an error and raises ``RestoreStop`` when the feed has no entry
        for the requested variant.
        """
        fw_info = self.get_fw_variants(target).get(variant)
        if fw_info is None:
            LOG.error(
                "Stock firmware for platform '{}' (variant '{}') is not available!".format(
                    target, variant
                )
            )
            raise RestoreStop
        return fw_info


def check_compatibility(platform: Platform):
    """Abort unless *platform* was detected and reports Braiins OS.

    Logs the reason and raises ``RestoreStop`` when *platform* is falsy or
    when its ``os_info.type`` differs from ``OsType.BOS``.
    """
    if not platform:
        reason = 'Stock firmware is being uninstalled from unsupported platform!'
    elif platform.os_info.type != OsType.BOS:
        reason = 'Remote target is running different firmware then Braiins OS!'
    else:
        return

    LOG.error(reason)
    raise RestoreStop


def _get_backup_dir(cache: RestoreCache, args):
    """Resolve ``args.backup_path`` to a directory with backup data.

    An unset path or an existing directory is returned unchanged. Anything
    else is treated as a tarball: it is extracted through *cache* and the
    directory holding the first ``*/uEnv.txt`` match is returned.

    Raises ``RestoreStop`` when the extracted tarball has no ``uEnv.txt``
    one level below its root.
    """
    backup_path = args.backup_path
    if not backup_path or os.path.isdir(backup_path):
        return backup_path

    LOG.info('Extracting backup tarball...')
    extracted_root = cache.extract_backup(backup_path)
    uenv_matches = glob(os.path.join(extracted_root, '*', 'uEnv.txt'))
    if not uenv_matches:
        LOG.error('Invalid backup tarball!')
        raise RestoreStop
    return os.path.dirname(uenv_matches[0])


def _reboot_to_factory(host, ssh, mtdparts_params):
    """Store the target MTD parts on the miner and restart it into recovery.

    Blocks until ``wait_for_port`` returns for SSH port 22 on *host*.
    """
    # the variable name is the constant without its last character
    env_name = backup.RECOVERY_MTDPARTS[:-1]
    quoted_params = '"{}"'.format(mtdparts_params)
    ssh.run('fw_setenv', env_name, quoted_params)
    ssh.run('miner', 'run_recovery')

    # carry on only after the miner is reachable again
    print('Rebooting to recovery...', end='')
    wait_for_port(host, 22, REBOOT_DELAY)


def uninstall(cache: RestoreCache, args, host, username, password):
    """Restore stock firmware on the miner at *host*.

    Connects over SSH, verifies the remote system, prepares the backup and
    (unless ``args.nand_restore`` is set) the stock firmware, then restores.
    When the miner runs from SD or in recovery mode the restore happens
    immediately; otherwise the miner is rebooted to recovery mode first and
    the restore is done over a second connection.

    Raises ``RestoreStop`` when the remote system is not compatible, the
    firmware or backup is unusable, or the miner does not come back in
    recovery mode.
    """
    LOG.info('Connecting to {}...'.format(host))
    with SSHManager(host, username, password, load_host_keys=False) as ssh:
        platform = get_platform_from_ssh(ssh)
        # refuse to continue on anything but a supported Braiins OS system
        check_compatibility(platform)

        current_mode = platform.os_info.mode
        LOG.info('Detected BOS mode: {}'.format(current_mode.name))

        backup_dir = _get_backup_dir(cache, args)
        firmware_dir = None
        if not args.nand_restore:
            target_name = platform.triple.target_full_name
            variant_name = str(platform.triple.variant)
            stock_fw = cache.get_firmware_info(target_name, variant_name)

            LOG.info("Preparing stock firmware '{}'...".format(stock_fw.name))
            firmware_dir = stock_fw.extract_to_tmp_dir()

        mtdparts_params = platform.get_factory_mtdparts(args, ssh, backup_dir)
        mtdparts = list(backup.parse_mtdparts(mtdparts_params))

        if current_mode in (OsMode.sd, OsMode.recovery):
            # no reboot needed: restore directly from SD or recovery mode
            platform.restore_firmware(args, ssh, backup_dir, mtdparts, firmware_dir)
            return

        # switch the miner to recovery mode with the target MTD parts
        _reboot_to_factory(host, ssh, mtdparts_params)

    LOG.info('Reconnecting to {}...'.format(host))
    # NOTE(review): the reconnect uses the module defaults instead of the
    # supplied credentials - presumably recovery mode only accepts the
    # default root login; confirm
    with SSHManager(host, USERNAME, PASSWORD, load_host_keys=False) as ssh:
        platform = get_platform_from_ssh(ssh)
        check_compatibility(platform)
        if platform.os_info.mode != OsMode.recovery:
            LOG.error('Could not reboot to recovery mode!')
            raise RestoreStop
        # the miner is in recovery mode now, finish the restore
        platform.restore_firmware(args, ssh, backup_dir, mtdparts, firmware_dir)


def main(parser, args):
    """Uninstall Braiins OS from every host given on the command line.

    *parser* is only used to report invalid argument combinations through
    ``parser.error``; *args* is the parsed argument namespace.

    More than one host switches to batch mode, where a custom backup and a
    full NAND restore are rejected and a default password is asked for when
    none was given.
    """
    hosts = read_hosts(args.hosts)
    batch = len(hosts) > 1

    # reject invalid combinations before the user is asked for a password
    if batch and args.backup_path:
        # Custom backups contain the device MAC address,
        # restoring this en masse may not be a good idea
        parser.error('positional argument backup_path: not allowed in batch mode')
    if batch and args.nand_restore:
        parser.error('argument --nand-restore: not allowed in batch mode')
    if args.nand_restore and not args.backup_path:
        # a full NAND restore has nothing to restore without a backup
        parser.error('argument --nand-restore: missing positional argument backup_path')

    if batch:
        # the user name is not handled at all since root is required;
        # the SSH wrapper may ask for a password based on its own logic,
        # only the default is provided here
        password = args.password or getpass('Default password: ') or PASSWORD
    else:
        password = args.password or PASSWORD

    feeds_url = args.feeds_url or BOS_PLUS_URl

    with WebCacheContext() as web_cache:
        cache = RestoreCache(web_cache, feeds_url)
        for host in hosts:
            uninstall(cache, args, host, USERNAME, password)


def build_arg_parser(parser):
    """Register the description and arguments of the uninstall command.

    Adds the optional positionals ``hosts`` and ``backup_path``, the
    ``-p/--password`` option, and the mutually exclusive options
    ``--feeds-url`` and ``--nand-restore`` to *parser*.
    """
    parser.description = 'Uninstall Braiins OS[+] from the mining machine'

    add = parser.add_argument
    add(
        'hosts',
        nargs='?',
        help='hostname or path to file with hosts of miners with Braiins OS',
    )
    add(
        'backup_path',
        nargs='?',
        help='path to directory or tgz file with data for miner restore',
    )
    add('-p', '--password', default='', help='administration password')

    # a custom feeds server makes no sense together with a full NAND restore
    source_group = parser.add_mutually_exclusive_group()
    source_group.add_argument(
        '--feeds-url', nargs='?', help='override default feeds server URL'
    )
    source_group.add_argument(
        '--nand-restore',
        action='store_true',
        help='use full NAND restore from previous backup',
    )


if __name__ == '__main__':
    # stand-alone entry point: build the parser, run, map failures to exit codes
    cli_parser = argparse.ArgumentParser()
    build_arg_parser(cli_parser)
    cli_args = cli_parser.parse_args(sys.argv[1:])

    exit_code = 0
    try:
        main(cli_parser, cli_args)
    except SSHError as error:
        # SSH failures are the only ones reported here; the others are
        # expected to have been logged where they were raised
        LOG.error(str(error))
        exit_code = 1
    except RestoreStop:
        exit_code = 2
    except PlatformStop:
        exit_code = 3

    if exit_code:
        sys.exit(exit_code)
