#!/usr/bin/env python3

# Copyright (C) 2020  Braiins Systems s.r.o.
#
# This file is part of Braiins Open-Source Initiative (BOSI).
#
# BOSI is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Please, keep in mind that we may also license BOSI or any part thereof
# under a proprietary license. For more information on the terms and conditions
# of such proprietary license or if you have any other questions, please
# contact us at opensource@braiins.com.

import argparse
import io
import logging
import os
import shutil
import sys
import time

from getpass import getpass
from subprocess import CalledProcessError

from bos_toolbox.batch import read_hosts
from bos_toolbox.common import (
    BOS_REFERRAL_URL,
    BOS_REFERRAL_FEEDS_NAME,
    BOS_REFERRAL_PKG_NAME,
    BOS_CUSTOM_FEEDS_CFG,
    BOS_FIRMWARE_PKG_NAME,
)
from bos_toolbox.util import get_refid
from bos_utils.ssh import SSHManager

LOG = logging.getLogger(__name__)

USERNAME = 'root'


class UpdateStop(Exception):
    pass


def main(parser, args):
    hosts = read_hosts(args.hosts)
    batch = len(hosts) > 1
    if batch:
        # ssh wrapper may ask for password based on it's own logic, we just provide default
        password = args.password or getpass('Default password: ') or ''
    else:
        password = args.password or ''

    error_count = 0
    for host in hosts:
        try:
            update_one(host, password, args.package)
        except UpdateFail as ex:
            error_count += 1
            if not args.ignore:
                raise
            LOG.error(f'Updating {host} failed: {ex}')
        except CalledProcessError as ex:
            error_count += 1
            print(ex.stdout.read().decode())
            print(ex.stderr.read().decode())
            if not args.ignore:
                raise UpdateFail(f'process returned {ex.returncode}')
            LOG.error(f'Updating {host} failed ({ex.returncode})')
        except Exception as ex:
            error_count += 1
            if not args.ignore:
                raise
            LOG.error(f'Updating {host} failed ({ex})')

    if error_count:
        sys.exit(f'{error_count} errors encountered')


def prepare_referral(ssh):
    refid = get_refid()
    if not refid:
        return

    record = f'src/gz {BOS_REFERRAL_FEEDS_NAME} {BOS_REFERRAL_URL}/{refid}\n'
    stream = io.StringIO()
    with ssh.open(BOS_CUSTOM_FEEDS_CFG) as file:
        for line in file:
            if (
                record
                and not line.startswith('#')
                and f' {BOS_REFERRAL_FEEDS_NAME} ' in line
            ):
                line = record
                record = None
            stream.write(line)
    record and stream.write(record)
    stream.seek(0)

    with ssh.open(BOS_CUSTOM_FEEDS_CFG, mode='w') as file:
        shutil.copyfileobj(stream, file)


def update(ssh, package=BOS_FIRMWARE_PKG_NAME):
    stdout, stderr = None, None
    commands = ['opkg update', f'opkg install {package}']
    while commands:
        try:
            stdout, stderr = ssh.run(commands[0])
            commands.pop(0)
        except CalledProcessError as ex:
            error_msg = ex.stderr.read().decode().strip()
            std_msg = ex.stdout.read().decode().strip()
            if package == BOS_FIRMWARE_PKG_NAME and 'Running system upgrade' in std_msg:
                # ignore error because this signals start of system upgrade
                return
            if '/var/lock/opkg.lock' in error_msg:
                # try again until resources are available
                time.sleep(0.5)
                continue
            if std_msg:
                # dump out what was received on stdout, just in case
                print(std_msg)
            raise UpdateFail(error_msg)

    std_msg = stdout and stdout.read().decode().rstrip()
    if std_msg:
        print(std_msg)
    error_msg = stderr and stderr.read().decode().strip()
    if error_msg:
        raise UpdateFail(error_msg)


def update_one(host, password, package):
    LOG.info(f'Updating {package} on {host}...')
    with SSHManager(host, USERNAME, password, load_host_keys=False) as ssh:
        if package == BOS_REFERRAL_PKG_NAME:
            prepare_referral(ssh)
        update(ssh, package)


class UpdateFail(RuntimeError):
    pass


def build_arg_parser(parser):
    parser.description = (
        'Update system package on mining machines running Braiins OS[+]'
    )

    parser.add_argument(
        'hosts',
        nargs='?',
        help='hostname or path to file with hosts of miners with Braiins OS',
    )
    parser.add_argument(
        'package',
        nargs='?',
        default=BOS_FIRMWARE_PKG_NAME,
        help='Package name for update',
    )
    parser.add_argument('-p', '--password', default='', help='administration password')
    parser.add_argument('-i', '--ignore', action='store_true', help='no halt on errors')


if __name__ == '__main__':
    try:
        parser = argparse.ArgumentParser()
        build_arg_parser(parser)
        args = parser.parse_args()
        main(parser, args)
    except KeyboardInterrupt:
        sys.exit(1)
    except UpdateStop:
        sys.exit(2)
    except Exception as ex:
        sys.exit(f'error: {ex}')
