#!/usr/bin/env python3
# SPDX-FileCopyrightText: Christian Amsüss and the aiocoap contributors
#
# SPDX-License-Identifier: MIT

"""A client suitable for running the OSCORE plug test series against a given
server

See https://github.com/EricssonResearch/OSCOAP for the test suite
description."""

import argparse
import asyncio
import logging
import signal
import functools
from pathlib import Path
import sys

from aiocoap import *
from aiocoap import error
from aiocoap import interfaces
from aiocoap import credentials

# In some nested combinations of unittest and coverage, the usually
# provided-by-default inclusion of local files does not work. Ensuring the
# local plugtest_common *can* be included.
import os.path
sys.path.append(os.path.dirname(__file__))
from plugtest_common import *


class PlugtestClientProgram:
    """Command-line client running OSCORE plug test cases against a server.

    The individual test cases are the ``test_*`` methods; each is registered
    under one or more test numbers by the ``__implements_tests`` decorator,
    and :meth:`run_test` dispatches on that number. Results of the checks are
    reported on standard output.
    """

    async def run(self):
        """Parse the command line, set up the client context and run tests.

        If a test number is given on the command line, only that test is run.
        Otherwise test numbers are read interactively until the user enters
        ``q``, closes the input or presses Ctrl-C.
        """
        p = argparse.ArgumentParser(description="Client for the OSCORE plug test.")
        p.add_argument("host", help="Hostname of the server")
        p.add_argument("contextdir", help="Directory name where to persist sequence numbers", type=Path)
        p.add_argument("testno", nargs="?", type=int, help="Test number to run (integer part); leave out for interactive mode")
        p.add_argument("--verbose", help="Show aiocoap debug messages", action='store_true')
        opts = p.parse_args()

        self.host = opts.host

        # this also needs to be called explicitly as only the
        # 'logging.warning()'-style functions will call it; creating a
        # sub-logger and logging from there makes the whole logging system not
        # emit the 'WARNING' prefixes that set apart log messages from regular
        # prints and also help the test suite catch warnings and errors
        if opts.verbose:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(level=logging.WARNING)

        security_context_a = get_security_context('a', opts.contextdir / "a")
        security_context_c = get_security_context('c', opts.contextdir / "c")

        # The contexts are stored under the symbolic names ":ab" and ":cd";
        # use_context() later points the server's URIs at one of them.
        self.ctx = await Context.create_client_context()
        self.ctx.client_credentials[":ab"] = security_context_a
        self.ctx.client_credentials[":cd"] = security_context_c

        if opts.testno is not None:
            await self.run_test(opts.testno)
            return

        # Interactive mode: empty input runs `next_testno`, which advances by
        # `delta` after every executed test.
        next_testno = 0
        delta = 1

        while True:
            # Yes this is blocking, but since the tests usually terminate
            # by themselves ... *shrug*
            try:
                entered = input("Enter a test number (empty input runs %s, q to quit): " % next_testno)
            except (KeyboardInterrupt, EOFError):
                break
            if entered == "q":
                break
            if entered:
                try:
                    as_int = int(entered)
                except ValueError:
                    print("That's not a number.")
                    continue

                new_delta = as_int - next_testno + delta
                if new_delta in (0, 1):
                    # Don't jump around randomly if user jumped around,
                    # but otherwise do something sane
                    delta = new_delta
                next_testno = as_int

            # Without this check, an unregistered number (e.g. the one after
            # the last test) would end the session with a KeyError traceback.
            if next_testno not in self.__methods:
                print("There is no test number %s." % next_testno)
                continue

            print("Running test %s" % next_testno)
            try:
                await self.run_test(next_testno)
            except error.Error as e:
                print("Test failed with an exception:", e)
            print()
            next_testno += delta

    # Class-level default so that run_with_shutdown() can tell whether run()
    # got as far as creating the client context.
    ctx = None

    async def run_with_shutdown(self):
        """Run :meth:`run` and shut down the client context afterwards."""
        # Having SIGTERM cause a more graceful shutdown (even if it's not
        # asynchronously awaiting the shutdown, which would be impractical
        # since we're likely inside some unintended timeout already) allow for
        # output buffers to be flushed when the unit test program instrumenting
        # it terminates it.
        #
        # NOTE(review): asyncio documents that loop.close() must not be called
        # while the loop is running; as this handler fires on the running
        # loop, it presumably raises RuntimeError rather than shutting down --
        # confirm, and consider cancelling the main task instead.
        loop = asyncio.get_running_loop()
        loop.add_signal_handler(signal.SIGTERM, loop.close)

        try:
            await self.run()
        finally:
            if self.ctx is not None:
                await self.ctx.shutdown()

    def use_context(self, which):
        """Select the credentials used for requests to the tested host.

        :param which: Name of a stored security context without its leading
            colon (``"ab"`` or ``"cd"``), or ``None`` to remove the mapping
            for the host again.
        """
        if which is None:
            self.ctx.client_credentials.pop("coap://%s/*" % self.host, None)
        else:
            self.ctx.client_credentials["coap://%s/*" % self.host] = ":" + which

    async def run_test(self, testno):
        """Run the test case registered under the number `testno`.

        The number is stored in ``self.testno`` as some test methods serve
        several numbers and branch on it.

        :raises KeyError: if no test is registered under `testno`.
        """
        self.testno = testno
        testfun = self.__methods[testno]
        try:
            await getattr(self, testfun)()
        except oscore.NotAProtectedMessage as e:
            print("Response carried no Object-Security option, but was: %s %s"%(e.plain_message, e.plain_message.payload))
            raise

    # Maps test numbers to the names of the methods implementing them.
    __methods = {}
    def __implements_tests(numbers, __methods=__methods):
        """Build a decorator registering a method under the given test numbers.

        This is only called while the class body is executed; the registry is
        bound as a default argument as the class's name is not available yet
        at that time.
        """
        def registerer(method):
            for n in numbers:
                __methods[n] = method.__name__
            return method
        return registerer

    @__implements_tests([0])
    async def test_plain(self):
        """Test 0: GET /oscore/hello/coap with no security context selected."""
        request = Message(code=GET, uri='coap://' + self.host + '/oscore/hello/coap')

        self.use_context(None)

        response = await self.ctx.request(request).response

        print("Response:", response)
        additional_verify("Responde had correct code", response.code, CONTENT)
        additional_verify("Responde had correct payload", response.payload, b"Hello World!")
        additional_verify("Options as expected", response.opt, Message(content_format=0).opt)

    @__implements_tests([1, 2])
    async def test_hello12(self):
        """Tests 1 and 2: GET /oscore/hello/1 using context "ab" (test 1) or
        "cd" (test 2)."""
        if self.testno == 1:
            self.use_context("ab")
        else:
            self.use_context("cd")

        request = Message(code=GET, uri='coap://' + self.host+ '/oscore/hello/1')
        expected = {'content_format': 0}
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", unprotected_response.code, CONTENT)
        additional_verify("Options as expected", unprotected_response.opt, Message(**expected).opt)
        additional_verify("Payload as expected", unprotected_response.payload, b"Hello World!")

    @__implements_tests([3])
    async def test_hellotest3(self):
        """Test 3: GET /oscore/hello/2 with a query string; an ETag is
        expected in the response."""
        self.use_context("ab")

        request = Message(code=GET, uri='coap://' + self.host+ '/oscore/hello/2?first=1')
        expected = {'content_format': 0, 'etag': b'\x2b'}
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", unprotected_response.code, CONTENT)
        additional_verify("Options as expected", unprotected_response.opt, Message(**expected).opt)
        additional_verify("Payload as expected", unprotected_response.payload, b"Hello World!")

    @__implements_tests([4])
    async def test_hellotest4(self):
        """Test 4: GET /oscore/hello/3 with an Accept option; a Max-Age of 5
        is expected in the response."""
        self.use_context("ab")

        request = Message(code=GET, uri='coap://' + self.host+ '/oscore/hello/3', accept=0)
        expected = {'content_format': 0, 'max_age': 5}
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", unprotected_response.code, CONTENT)
        additional_verify("Options as expected", unprotected_response.opt, Message(**expected).opt)
        additional_verify("Payload as expected", unprotected_response.payload, b"Hello World!")

    @__implements_tests([5])
    async def test_nonobservable(self):
        """Test 5: request observation of /oscore/hello/1; the response is
        expected to carry no Observe option and no notifications to follow."""
        self.use_context("ab")

        request = Message(code=GET, uri='coap://' + self.host + '/oscore/hello/1', observe=0)

        request = self.ctx.request(request)

        unprotected_response = await request.response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", unprotected_response.code, CONTENT)
        additional_verify("Observe option is absent", unprotected_response.opt.observe, None)

        async for o in request.observation:
            print("Expectation failed: Observe events coming in.")

    @__implements_tests([6, 7])
    async def test_observable(self):
        """Tests 6 and 7: observe /oscore/observe1 (test 6) or
        /oscore/observe2 (test 7) and compare the sequence of payloads.

        Test 6 collects notifications until the observation ends; test 7
        stops after the second payload.
        """
        self.use_context("ab")

        request = Message(code=GET, uri='coap://' + self.host + '/oscore/observe%s' % (self.testno - 5), observe=0)

        request = self.ctx.request(request)

        unprotected_response = await request.response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", unprotected_response.code, CONTENT)
        additional_verify("Observe option present", unprotected_response.opt.observe is not None, True)

        payloads = [unprotected_response.payload]

        async for o in request.observation:
            # FIXME: where's the 'two' stuck?
            payloads.append(o.payload)
            print("Verify: Received message", o, o.payload)
            if len(payloads) == 2 and self.testno == 7:
                # FIXME: Not yet ready to send actual cancellations
                break

        expected_payloads = [b'one', b'two']
        if self.testno == 6:
            expected_payloads.append(b'Terminate Observe')
        additional_verify("Server gave the correct responses", payloads, expected_payloads)

    @__implements_tests([8])
    async def test_post(self):
        """Test 8: POST to /oscore/hello/6; the payload is expected back."""
        self.use_context("ab")

        request = Message(code=POST, uri='coap://' + self.host+ '/oscore/hello/6', payload=b"\x4a", content_format=0)
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", CHANGED, unprotected_response.code)
        additional_verify("Options as expected", unprotected_response.opt, Message(content_format=0).opt)
        additional_verify("Payload as expected", unprotected_response.payload, b"\x4a")

    @__implements_tests([9])
    async def test_put_match(self):
        """Test 9: PUT to /oscore/hello/7 with an If-Match option; CHANGED
        is expected."""
        self.use_context("ab")

        request = Message(code=PUT, uri='coap://' + self.host+ '/oscore/hello/7', payload=b"\x7a", content_format=0, if_match=[b"\x7b"])
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", CHANGED, unprotected_response.code)
        additional_verify("Options empty as expected", Message().opt, unprotected_response.opt)
        additional_verify("Payload absent as expected", unprotected_response.payload, b"")

    @__implements_tests([10])
    async def test_put_nonmatch(self):
        """Test 10: PUT to /oscore/hello/7 with If-None-Match set;
        PRECONDITION_FAILED is expected."""
        self.use_context("ab")

        request = Message(code=PUT, uri='coap://' + self.host+ '/oscore/hello/7', payload=b"\x8a", content_format=0, if_none_match=True)
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", PRECONDITION_FAILED, unprotected_response.code)
        additional_verify("Options empty as expected", Message().opt, unprotected_response.opt)

    @__implements_tests([11])
    async def test_delete(self):
        """Test 11: DELETE /oscore/test; DELETED is expected."""
        self.use_context("ab")

        request = Message(code=DELETE, uri='coap://' + self.host+ '/oscore/test')
        unprotected_response = await self.ctx.request(request).response

        print("Unprotected response:", unprotected_response)
        additional_verify("Code as expected", DELETED, unprotected_response.code)
        additional_verify("Options empty as expected", Message().opt, unprotected_response.opt)

    @__implements_tests([12, 13, 14])
    async def test_oscoreerror_server_reports_error(self):
        """Tests 12 to 14: send a request with one field of the "ab" context
        deliberately corrupted and check how the exchange fails.

        Test 12 corrupts the sender ID, test 13 the sender key (an
        unprotected error response is expected for both) and test 14 the
        recipient key (response validation is expected to fail). The original
        value is restored and the context stored again afterwards.
        """
        self.use_context("ab")
        secctx = self.ctx.client_credentials[":ab"]

        if self.testno == 12:
            expected_code = UNAUTHORIZED
            expected_error = oscore.NotAProtectedMessage
            # FIXME: frobbing the sender_id breaks sequence numbers...
            frobnicate_field = 'sender_id'
        elif self.testno == 13:
            expected_code = BAD_REQUEST
            expected_error = oscore.NotAProtectedMessage
            frobnicate_field = 'sender_key'
        elif self.testno == 14:
            expected_code = None
            expected_error = oscore.ProtectionInvalid
            frobnicate_field = 'recipient_key'

        # Invert every byte so the value is guaranteed to differ; an empty
        # value has no bytes to invert and is replaced by a fixed string.
        unfrobnicated = getattr(secctx, frobnicate_field)
        if unfrobnicated == b'':
            setattr(secctx, frobnicate_field, b'spam')
        else:
            setattr(secctx, frobnicate_field, bytes((255 - x) for x in unfrobnicated))

        request = Message(code=GET, uri='coap://' + self.host + '/oscore/hello/1')

        try:
            unprotected_response = await self.ctx.request(request).response
        except expected_error as e:
            if expected_code is not None:
                if e.plain_message.code == expected_code:
                    print("Check passed: The server responded with unencrypted %s."%(expected_code))
                else:
                    print("Failed: Server responded with something unencrypted, but not the expected code %s: %s"%(expected_code, e.plain_message))
            else:
                print("Check passed: The validation failed. (%s)" % e)
        else:
            print("Failed: The validation passed.")
            print("Unprotected response:", unprotected_response)
        finally:
            setattr(secctx, frobnicate_field, unfrobnicated)
            # With a frobnicated sender_id, the stored context could not be
            # loaded for later use; making sure it's stored properly again.
            secctx._store()

    @__implements_tests([15])
    async def test_replay(self):
        """Test 15: send a request, rewind the sender sequence number by one
        and send it again; an unprotected UNAUTHORIZED response is expected
        for the replay."""
        self.use_context("ab")

        request = Message(code=GET, uri='coap://' + self.host + '/oscore/hello/1')

        unprotected_response = await self.ctx.request(request).response # make this _nonraising as soon as there's a proper context backend
        if unprotected_response.code != CONTENT:
            print("Failed: Request did not even pass before replay (%s)"%unprotected_response)
            return

        # Rewinding makes the next request reuse the sequence number that was
        # just spent.
        secctx = self.ctx.client_credentials[":ab"]
        secctx.sender_sequence_number -= 1

        try:
            unprotected_response = await self.ctx.request(request).response
        except oscore.NotAProtectedMessage as e:
            if e.plain_message.code == UNAUTHORIZED:
                print("Check passed: The server responded with unencrypted replay error.")
            else:
                # Report the code that was expected, not the one received
                # (which is part of the printed message anyway).
                print("Failed: Server responded with something unencrypted, but not the expected code %s: %s"%(UNAUTHORIZED, e.plain_message))
        else:
            print("Failed: the validation passed.")
            print("Unprotected response:", unprotected_response)

    @__implements_tests([16])
    async def test_nonoscore_server(self):
        """Test 16: send a protected request to /oscore/hello/coap; an
        unprotected BAD_OPTION response is expected."""
        self.use_context("ab")

        request = Message(code=GET, uri='coap://' + self.host+ '/oscore/hello/coap')

        try:
            response = await self.ctx.request(request).response
        except oscore.NotAProtectedMessage as e:
            if e.plain_message.code == BAD_OPTION:
                print("Passed: Server reported bad option.")
            elif e.plain_message.code == METHOD_NOT_ALLOWED:
                print("Dubious: Server reported Method Not Allowed.")
            else:
                print("Failed: Server reported %s" % e.plain_message.code)
        else:
            print("Failed: The server accepted an OSCORE message")

    @__implements_tests([17])
    async def test_unauthorized(self):
        """Test 17: GET /oscore/hello/1 with no security context selected;
        UNAUTHORIZED is expected."""
        request = Message(code=GET, uri='coap://' + self.host + '/oscore/hello/1')

        self.use_context(None)

        response = await self.ctx.request(request).response

        print("Response:", response, response.payload)
        additional_verify("Responde had correct code", response.code, UNAUTHORIZED)

#     # unofficial blockwise tests start here
# 
#     @__implements_tests([16, 17])
#     async def test_block2(self):
#         #request = Message(code=GET, uri='coap://' + self.host + '/oscore/block/' + {16: 'outer', 17: 'inner'}[self.testno])
#         request = Message(code=GET, uri='coap://' + self.host + '/oscore/LargeResource')
# 
#         expected = {'content_format': 0}
#         unprotected_response = await self.ctx.request(request, handle_blockwise=True).response
#         if self.testno == 17:
#             # the library should probably strip that
#             expected['block2'] = optiontypes.BlockOption.BlockwiseTuple(block_number=1, more=False, size_exponent=6)
# 
#         print("Unprotected response:", unprotected_response)
#         additional_verify("Code as expected", unprotected_response.code, CONTENT)
#         additional_verify("Options as expected", unprotected_response.opt, Message(**expected).opt)

if __name__ == "__main__":
    # Script entry point: run the client until it finishes, making sure the
    # client context is shut down on the way out.
    program = PlugtestClientProgram()
    asyncio.run(program.run_with_shutdown())
