# Copyright (C) 2020  Braiins Systems s.r.o.
#
# This file is part of Braiins Open-Source Initiative (BOSI).
#
# BOSI is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Please, keep in mind that we may also license BOSI or any part thereof
# under a proprietary license. For more information on the terms and conditions
# of such proprietary license or if you have any other questions, please
# contact us at opensource@braiins.com.

import re
import subprocess

from bos_utils.ssh import SSHManager

from .common import Probe, Platform
from .descriptor import SubTarget, OsInfo, OsType, OsMode


class SSHProbe(Probe):
    """Base probe that gathers device information by running commands over SSH.

    Commands that fail with :class:`subprocess.CalledProcessError` are reported
    as ``None`` so that callers can treat missing information uniformly.
    """

    def __init__(self, ssh: SSHManager):
        super().__init__()
        self._ssh = ssh

    def _run_to_line(self, *args):
        """Run a command and return the first line of its output, stripped.

        :param args:
            Command and arguments forwarded to ``SSHManager.run``.
        :return:
            First stripped stdout line, or None when the command fails or
            prints nothing.
        """
        try:
            stdout = self._ssh.run(*args)[0]
            # the default prevents StopIteration from leaking out when the
            # command succeeds but prints no output
            line = next(iter(stdout), None)
        except subprocess.CalledProcessError:
            return None
        return line.strip() if line is not None else None

    def _run_to_lines(self, *args):
        """Run a command and return its output lines, each stripped.

        :param args:
            Command and arguments forwarded to ``SSHManager.run``.
        :return:
            Lazy generator of stripped stdout lines, or None when
            ``SSHManager.run`` raises CalledProcessError.
        """
        # NOTE(review): stripping happens lazily; if SSHManager.run reports a
        # failure only while stdout is being consumed, it is not caught here.
        try:
            stdout, _ = self._ssh.run(*args)
            return (line.strip() for line in stdout)
        except subprocess.CalledProcessError:
            return None

    def get_mac(self):
        """Return the contents of ``/sys/class/net/eth0/address`` (or None)."""
        return self._run_to_line('cat', '/sys/class/net/eth0/address')


class SSHProbeBOS(SSHProbe):
    """SSH probe for hosts running Braiins OS."""

    def __init__(self, ssh: SSHManager):
        super().__init__(ssh)

    def get_version(self) -> str:
        """Return the Braiins OS version string.

        The version is read from ``/etc/bos_version``. Old firmwares lack that
        file, so the version part of the installed ``firmware*`` opkg package
        entry (``<name> - <version>``) is used instead.

        :return:
            Version string, or None when it cannot be determined.
        """
        bos_version = self._run_to_line('cat', '/etc/bos_version')
        if bos_version:
            return bos_version
        # fallback for old firmwares
        installed = self.get_opkg_installed() or []
        fw_version = next((fw for fw in installed if fw.startswith('firmware')), None)
        match = fw_version and re.search(r' - (.*)', fw_version)
        # return None (not an empty list) when the entry has no version part
        return match.group(1) if match else None

    def get_board_name(self) -> str:
        """Return the contents of ``/tmp/sysinfo/board_name`` (or None)."""
        return self._run_to_line('cat', '/tmp/sysinfo/board_name')

    def detect_mode(self) -> OsMode or None:
        """Detect which mode the running system is in.

        ``/etc/bos_mode`` is used when present; its content must be the name
        of an ``OsMode`` member, otherwise ``KeyError`` is raised. Without the
        file, the device backing ``/overlay`` in ``mount`` output decides:
        ``/dev/ubi0_2`` means NAND, ``/dev/mmcblk0p2`` means SD card and
        anything else means recovery.
        """
        mode = self._run_to_line('cat', '/etc/bos_mode')
        if mode is not None:
            return OsMode[mode]
        # fallback for old releases
        for line in self._run_to_lines('mount') or []:
            if line.startswith('/dev/ubi0_2 on /overlay'):
                return OsMode.nand
            elif line.startswith('/dev/mmcblk0p2 on /overlay'):
                return OsMode.sd
        else:
            # reached only when no overlay mount line matched
            return OsMode.recovery

    def get_opkg_installed(self):
        """Return stripped lines of ``opkg list-installed`` (or None)."""
        return self._run_to_lines('opkg', 'list-installed')

    def get_net_proto(self):
        """Return the UCI value ``network.lan.proto`` (or None)."""
        return self._run_to_line('uci', 'get', 'network.lan.proto')

    def get_net_hostname(self):
        """Return the configured LAN hostname, falling back to the kernel one."""
        return self._run_to_line(
            'uci', 'get', 'network.lan.hostname'
        ) or self._run_to_line('cat', '/proc/sys/kernel/hostname')

    def get_net_ipaddr(self):
        """Return the UCI value ``network.lan.ipaddr`` (or None)."""
        return self._run_to_line('uci', 'get', 'network.lan.ipaddr')

    def get_net_mask(self):
        """Return the UCI value ``network.lan.netmask`` (or None)."""
        return self._run_to_line('uci', 'get', 'network.lan.netmask')

    def get_net_gateway(self):
        """Return the UCI value ``network.lan.gateway`` (or None)."""
        return self._run_to_line('uci', 'get', 'network.lan.gateway')

    def get_net_dns(self):
        """Return the UCI value ``network.lan.dns`` (or None)."""
        return self._run_to_line('uci', 'get', 'network.lan.dns')

    def get_bosminer_config(self):
        """Return the decoded JSON printed by ``bosminer config --data``.

        Only the first output line is parsed.

        :return:
            Decoded JSON value, or None when the command fails, prints nothing
            or prints something that is not valid JSON.
        """
        import json

        config = self._run_to_line('bosminer', 'config', '--data')
        if not config:
            return None
        try:
            return json.loads(config)
        except ValueError:
            # malformed output must not abort the whole platform detection
            return None

    def get_bosminer_model(self) -> str or None:
        """Return ``data.format.model`` from the bosminer configuration.

        :return:
            The model value, or None when the configuration is unavailable or
            does not have the expected nested-object layout.
        """
        config = self.get_bosminer_config()
        if not isinstance(config, dict):
            return None
        config_data = config.get('data')
        if not isinstance(config_data, dict):
            return None
        config_format = config_data.get('format')
        if not isinstance(config_format, dict):
            return None
        return config_format.get('model')


class SSHProbeBitmain(SSHProbe):
    """SSH probe for hosts running stock Bitmain firmware (or a derivative)."""

    def __init__(self, ssh: SSHManager):
        # nothing firmware-specific to set up; defer to the generic SSH probe
        super(SSHProbeBitmain, self).__init__(ssh)

    def get_compile_time(self):
        """Return the lines of ``/usr/bin/compile_time``.

        :return:
            Lazy iterable of stripped lines, or None when reading the file
            fails with CalledProcessError.
        """
        command = ('cat', '/usr/bin/compile_time')
        return self._run_to_lines(*command)


def _create_platform_from_ssh_bos(ssh: SSHManager) -> Platform or None:
    """Detect a platform on a host running Braiins OS.

    :param ssh:
        SSH manager connected to the target host.
    :return:
        Platform instance for a supported board, or None when the host does
        not look like a Braiins OS device or its board name is not supported.
    """

    def create_am1_s9():
        from .am1_s9 import PlatformAm1

        return PlatformAm1(OsInfo(OsType.BOS, mode, bos_version), probe)

    def create_am2_x17():
        from .am2_x17 import PlatformAm2
        from .descriptor import X17Variant

        # the bosminer model string is split once on whitespace; the part
        # after the first token names the variant
        variant = X17Variant.X17
        if bosminer_model and len(bosminer_model) >= 2:
            try:
                variant = X17Variant(bosminer_model[1])
            except ValueError:
                # unknown model name: keep the generic variant, the same
                # outcome as when bosminer reports no model at all
                variant = X17Variant.X17
        return PlatformAm2(
            SubTarget.x17, variant, OsInfo(OsType.BOS, mode, bos_version), probe
        )

    probe = SSHProbeBOS(ssh)
    bos_version = probe.get_version()
    if not bos_version:
        return None
    mode = probe.detect_mode()
    # compare with None so that a falsy enum member (e.g. a zero-valued
    # IntEnum) is not mistaken for a detection failure
    if mode is None:
        return None
    board_name = probe.get_board_name()
    if not board_name:
        return None
    bosminer_model = probe.get_bosminer_model()
    bosminer_model = bosminer_model and bosminer_model.split(maxsplit=1)
    constructor = {'am1-s9': create_am1_s9, 'am2-s17': create_am2_x17}.get(board_name)
    return constructor and constructor()


def _create_platform_from_ssh_bitmain(ssh: SSHManager) -> Platform or None:
    """Detect a platform on a host running stock Bitmain or Vnish firmware.

    ``/usr/bin/compile_time`` is read. Its first line is taken as the OS
    version and its second line as the miner type (``Antminer <model> ...``).
    A third line, if any, is ignored. For Vnish firmware, the version embedded
    in ``(vnish <version>)`` replaces the OS version.

    :param ssh:
        SSH manager connected to the target host.
    :return:
        Platform instance, or None when the host is not recognised.
    """

    def create_am1_s9():
        from .am1_s9 import PlatformAm1

        return PlatformAm1(OsInfo(os_type, OsMode.nand, os_version), probe)

    def create_am2_x17():
        from .am2_x17 import PlatformAm2
        from .descriptor import X17Variant

        # 'Antminer S17 Pro' splits into three tokens; re-attach the suffix
        if len(miner_type) >= 3 and miner_type[2] == 'Pro':
            variant_name = miner_type[1] + ' Pro'
        else:
            variant_name = miner_type[1]
        try:
            variant = X17Variant(variant_name)
        except ValueError:
            # variant name unknown to X17Variant: treat host as unrecognised
            return None
        return PlatformAm2(
            SubTarget.x17, variant, OsInfo(os_type, OsMode.nand, os_version), probe
        )

    probe = SSHProbeBitmain(ssh)
    compile_time = probe.get_compile_time()
    if compile_time is None:
        return None
    # pad to at least three entries so that files with fewer lines still unpack
    os_version, miner_type, _logic_version = (list(compile_time) + [None] * 3)[:3]
    if not miner_type:
        # no miner type line: not a recognisable stock firmware
        return None

    # detect Vnish firmware which is based on stock firmware
    if ' (vnish ' in miner_type:
        os_type = OsType.Vnish
        m = re.findall(r' \(vnish (.*)\)', miner_type)
        if m:
            # use Vnish version as the OS version
            os_version = m[0]
    else:
        os_type = OsType.stock

    miner_type = miner_type.split()
    if len(miner_type) < 2 or miner_type[0] != 'Antminer':
        return None
    constructor = {
        'S9': create_am1_s9,
        'S9i': create_am1_s9,
        'S9j': create_am1_s9,
        'R4': create_am1_s9,
        'X17': create_am2_x17,
        'S17': create_am2_x17,
        'T17': create_am2_x17,
        'S17+': create_am2_x17,
        'T17+': create_am2_x17,
    }.get(miner_type[1])
    return constructor and constructor()


def get_platform_from_ssh(ssh: SSHManager) -> Platform or None:
    """Identify the platform of the host behind ``ssh``.

    Braiins OS detection is attempted first, then stock Bitmain/Vnish
    detection. Detectors run lazily, so a later one runs only when every
    earlier one returned a falsy result.

    :param ssh:
        SSH manager connected to the target host.
    :return:
        The first platform found, or None when no detector recognises the host.
    """
    detectors = (_create_platform_from_ssh_bos, _create_platform_from_ssh_bitmain)
    candidates = (detect(ssh) for detect in detectors)
    return next((found for found in candidates if found), None)
