# Copyright (C) 2020  Braiins Systems s.r.o.
#
# This file is part of Braiins Open-Source Initiative (BOSI).
#
# BOSI is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Please, keep in mind that we may also license BOSI or any part thereof
# under a proprietary license. For more information on the terms and conditions
# of such proprietary license or if you have any other questions, please
# contact us at opensource@braiins.com.

import io
import logging
import os
import shutil
import subprocess
import tarfile
import time

from tempfile import TemporaryDirectory

from .backup import ssh_factory_mtdparts, ssh_backup, ssh_restore, ssh_restore_reboot
from .backup import get_stream_size, get_default_hostname
from .common import Platform, PlatformStop
from .descriptor import TargetTriple, Target, SubTarget, OsInfo
from .detect import Probe

from bos_toolbox.transfer import upload_local_files, Progress
from bos_toolbox.util import get_payload_path

LOG = logging.getLogger(__name__)


class PlatformAntminer(Platform):
    BACKUP_SUFFIX = '_tmp'

    CONFIG_TAR = 'config.tar.gz'
    TARGET_DIR = '/tmp/bitmain_fw'

    RESTORE_DIR = get_payload_path('restore')

    def __init__(
        self,
        target: Target,
        sub_target: SubTarget,
        variant,
        os_info: OsInfo,
        probe: Probe,
    ):
        """Initialise the platform from its target description.

        The target, sub-target and variant are combined into a single
        ``TargetTriple`` which is handed over to the base class together with
        the OS information and the probe.
        """
        triple = TargetTriple(target, sub_target, variant)
        super().__init__(triple, os_info, probe)

    def get_default_factory_mtdparts(self):
        return None

    def get_system_binaries(self):
        return []

    def get_system_links(self):
        return []

    def get_firmware_path(self, firmware_dir):
        return firmware_dir

    def get_restore_name(self):
        return '{}_{}.sh'.format(self.triple.target.name, self.triple.sub_target.name)

    def backup_firmware(self, args, ssh, path, mac):
        """Save the miner configuration and then back up its firmware.

        The remote system is quiesced first, the remote ``/config`` directory
        is archived into ``CONFIG_TAR`` inside the local *path*, and the actual
        backup is delegated to ``ssh_backup`` whose result is returned.
        """
        LOG.info('Preparing system for backup...')
        # /tmp on AntMiner sits directly on UBIFS; cover it with tmpfs so that
        # daemons writing there no longer modify the NAND which is to be dumped
        ssh.run('mount', '-t', 'tmpfs', 'tmpfs', '/tmp/')
        # bmminer logs to /tmp, so it has to be stopped
        ssh.run('/etc/init.d/bmminer.sh', 'stop')
        # let the system kill all processes and release their handles
        time.sleep(1)
        # flush everything to NAND
        ssh.run('sync')

        LOG.info('Backuping configuration files...')
        config_path = os.path.join(path, self.CONFIG_TAR)
        with open(config_path, 'wb') as config_file:
            # the archive is produced remotely and streamed over the pipe
            with ssh.pipe('tar', 'cvzf', '-', '/config') as tar_process:
                shutil.copyfileobj(tar_process.stdout, config_file)

        # start backup process
        return ssh_backup(args, ssh, path, mac)

    def upload_bitmain_files(self, sftp, firmware_dir):
        """Transfer the original Bitmain firmware images needed for upgrade.

        The local source is resolved with ``get_firmware_path`` and passed to
        ``upload_local_files`` together with *sftp* and an empty third
        argument.
        """
        local_firmware_path = self.get_firmware_path(firmware_dir)
        upload_local_files(sftp, local_firmware_path, '')

    def get_factory_mtdparts(self, args, ssh, backup_dir):
        if backup_dir:
            return ssh_factory_mtdparts(args, ssh, backup_dir)
        else:
            return self.get_default_factory_mtdparts()

    def restore_bitmain_firmware(self, _args, ssh, backup_dir, firmware_dir):
        """Upload the Bitmain firmware and run the restore script remotely.

        ``TARGET_DIR`` is recreated on the remote side and filled over SFTP
        with the firmware images, the ``CONFIG_TAR`` archive taken from
        *backup_dir* and the platform restore script taken from
        ``RESTORE_DIR``. The script is then executed from within
        ``TARGET_DIR`` and the output of the command is printed.

        :param _args: unused
        :param ssh: remote connection providing ``run`` and ``open_sftp``
        :param backup_dir: local directory holding ``CONFIG_TAR``
        :param firmware_dir: local directory with the original firmware
        :raises PlatformStop: when the remote restore command fails; its
            error output is printed beforehand
        """
        # prepare target directory
        ssh.run('rm', '-fr', self.TARGET_DIR)
        ssh.run('mkdir', '-p', self.TARGET_DIR)

        # copy firmware files to the server over SFTP
        sftp = ssh.open_sftp()
        try:
            sftp.chdir(self.TARGET_DIR)

            LOG.info('Uploading firmware...')
            self.upload_bitmain_files(sftp, firmware_dir)

            LOG.info('Uploading restore scripts...')
            restore_name = self.get_restore_name()
            files = [
                (backup_dir, self.CONFIG_TAR),
                (self.RESTORE_DIR, restore_name),
            ]
            for local_dir, file_name in files:
                local_path = os.path.join(local_dir, file_name)
                with Progress(local_path, info=file_name) as progress:
                    sftp.put(local_path, file_name, callback=progress)
        finally:
            # release the SFTP session even when one of the uploads fails
            sftp.close()

        # run stage1 upgrade process
        try:
            LOG.info('Restoring firmware...')
            stdout, _ = ssh.run(
                'cd',
                self.TARGET_DIR,
                '&&',
                'ls',
                '-l',
                '&&',
                '/bin/sh {}'.format(restore_name),
            )
        except subprocess.CalledProcessError as error:
            for line in error.stderr.readlines():
                print(line, end='')
            # keep the failed command attached as the cause for debugging
            raise PlatformStop from error
        else:
            for line in stdout.readlines():
                print(line, end='')

    def create_bitmain_config(self, ssh, tmp_dir):
        """Create a ``CONFIG_TAR`` archive with Bitmain configuration files.

        The archive is written to *tmp_dir* and holds two members built from
        the values reported by the probe of the running miner:
        ``config/mac`` and ``config/network.conf``.

        :param ssh: unused
        :param tmp_dir: local directory the archive is created in
        """
        bitmain_hostname = 'antMiner'
        config_dir = 'config'

        # restore original configuration from running miner
        mac = self.probe.get_mac()

        config_path = os.path.join(tmp_dir, self.CONFIG_TAR)
        # the context managers close the archive and the in-memory streams
        # even when a probe query or a write fails half way through
        with tarfile.open(config_path, 'w:gz') as tar:
            # metadata of the archive file itself serves as a template for
            # the members; name and size are overwritten for each of them
            stream_info = tar.gettarinfo(config_path)

            # create mac file
            with io.BytesIO('{}\n'.format(mac).encode()) as stream:
                stream_info.name = '{}/mac'.format(config_dir)
                stream_info.size = get_stream_size(stream)
                tar.addfile(stream_info, stream)

            # create network.conf file
            net_proto = self.probe.get_net_proto()
            if net_proto == 'dhcp':
                net_hostname = self.probe.get_net_hostname()
                if net_hostname == get_default_hostname(mac):
                    # do not restore BOS default hostname
                    net_hostname = bitmain_hostname
                lines = [
                    'hostname={}\n'.format(net_hostname),
                    'dhcp=true\n',
                ]
            else:
                # static protocol
                net_ipaddr = self.probe.get_net_ipaddr()
                net_mask = self.probe.get_net_mask()
                net_gateway = self.probe.get_net_gateway()
                net_dns = self.probe.get_net_dns()
                lines = [
                    'hostname={}\n'.format(bitmain_hostname),
                    'ipaddress={}\n'.format(net_ipaddr),
                    'netmask={}\n'.format(net_mask),
                    'gateway={}\n'.format(net_gateway),
                    'dnsservers="{}"\n'.format(net_dns),
                ]
            with io.BytesIO(''.join(lines).encode()) as stream:
                stream_info.name = '{}/network.conf'.format(config_dir)
                stream_info.size = get_stream_size(stream)
                tar.addfile(stream_info, stream)

    def restore_firmware(self, args, ssh, backup_dir, mtdparts, firmware_dir):
        """Restore the original firmware on the miner.

        Without *firmware_dir* the default NAND dump restore is performed by
        ``ssh_restore``. Otherwise the Bitmain firmware is restored with the
        configuration taken from *backup_dir*, or from a freshly generated
        configuration when there is no backup, and the restore is finished by
        ``ssh_restore_reboot``.
        """
        os_mode = self.os_info.mode

        if not firmware_dir:
            # use default NAND dump restore
            ssh_restore(ssh, backup_dir, mtdparts, os_mode)
            return

        if backup_dir:
            self.restore_bitmain_firmware(args, ssh, backup_dir, firmware_dir)
        else:
            # no backup at hand: build the configuration in a scratch
            # directory which is used in place of the backup directory
            with TemporaryDirectory() as scratch_dir:
                LOG.info('Creating configuration files...')
                self.create_bitmain_config(ssh, scratch_dir)
                self.restore_bitmain_firmware(
                    args, ssh, scratch_dir, firmware_dir
                )
        ssh_restore_reboot(ssh, os_mode)

    def prepare_system(self, ssh, path):
        system_binaries = self.get_system_binaries()
        for file_name, remote_path in system_binaries:
            remote_file_name = '{}/{}'.format(remote_path, file_name)
            try:
                ssh.run('test', '!', '-e', remote_file_name)
            except subprocess.CalledProcessError:
                LOG.error(
                    "File '{}' exists on remote target already!".format(
                        remote_file_name
                    )
                )
                raise PlatformStop

        for file_name, remote_path in system_binaries:
            ssh.run('mkdir', '-p', remote_path)
            remote_file_name = '{}/{}'.format(remote_path, file_name)
            LOG.info('Copy {} to {}'.format(file_name, remote_file_name))
            ssh.put(os.path.join(path, file_name), remote_file_name)
            ssh.run('chmod', '+x', remote_file_name)

        for link_name, remote_path in self.get_system_links():
            try:
                ssh.run('test', '!', '-e', link_name)
            except subprocess.CalledProcessError:
                ssh.run('mv', link_name, link_name + self.BACKUP_SUFFIX)
            ssh.run('mkdir', '-p', os.path.dirname(link_name))
            ssh.run('ln', '-fs', remote_path, link_name)

        print()

    def cleanup_system(self, ssh):
        for file_name, remote_path in self.get_system_binaries():
            remote_file_name = '{}/{}'.format(remote_path, file_name)
            ssh.run('rm', '-r', remote_file_name)

        for link_name, remote_path in self.get_system_links():
            try:
                ssh.run('test', '!', '-e', link_name + self.BACKUP_SUFFIX)
            except subprocess.CalledProcessError:
                ssh.run('mv', link_name + self.BACKUP_SUFFIX, link_name)
            else:
                ssh.run('rm', link_name)
