//go:build !nosas
// +build !nosas

package olm

// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
// #include <olm/sas.h>
import "C"

import (
	"crypto/rand"
	"unsafe"
)

// SAS stores an Olm Short Authentication String (SAS) object.
//
// The underlying OlmSAS lives inside a Go-allocated byte slice rather than
// C-allocated memory; both the handle and the slice are kept here.
type SAS struct {
	// int is the libolm handle passed to every olm_sas_* call. NewBlankSAS
	// obtains it from olm_sas, which is handed the start of mem as storage.
	int *C.OlmSAS
	// mem is the backing storage for the OlmSAS. Holding it in the struct keeps
	// that memory referenced for as long as this SAS value is reachable.
	mem []byte
}

// NewBlankSAS allocates sasSize() bytes of Go memory and has libolm
// initialise an OlmSAS inside it via olm_sas. No randomness is supplied here;
// NewSAS does that afterwards through olm_create_sas.
func NewBlankSAS() *SAS {
	sas := &SAS{mem: make([]byte, sasSize())}
	sas.int = C.olm_sas(unsafe.Pointer(&sas.mem[0]))
	return sas
}

// sasSize reports, via olm_sas_size, how many bytes of backing memory one
// OlmSAS object requires.
func sasSize() uint {
	size := C.olm_sas_size()
	return uint(size)
}

// sasRandomLength asks libolm (olm_create_sas_random_length) how many random
// bytes olm_create_sas needs in order to create this SAS object.
func (sas *SAS) sasRandomLength() uint {
	needed := C.olm_create_sas_random_length(sas.int)
	return uint(needed)
}

// NewSAS creates a new SAS object seeded with fresh bytes from crypto/rand.
//
// It panics with NotEnoughGoRandom if crypto/rand cannot supply the bytes, and
// with the libolm error if olm_create_sas reports failure.
func NewSAS() *SAS {
	sas := NewBlankSAS()
	// One byte beyond what libolm asks for keeps &random[0] addressable even if
	// the reported length were zero. The full buffer length is passed, matching
	// the previous behavior.
	random := make([]byte, sas.sasRandomLength()+1)
	// These bytes are the secret input libolm uses to create the SAS object
	// (presumably its private key), so wipe Go's copy once the call is done
	// instead of leaving it on the heap. The deferred wipe also runs on panic.
	defer func() {
		for i := range random {
			random[i] = 0
		}
	}()
	if _, err := rand.Read(random); err != nil {
		panic(NotEnoughGoRandom)
	}
	r := C.olm_create_sas(
		sas.int,
		unsafe.Pointer(&random[0]),
		C.size_t(len(random)))
	if r == errorVal() {
		panic(sas.lastError())
	}
	return sas
}

// clear asks libolm (olm_clear_sas) to wipe the memory backing this SAS
// object, and returns libolm's result converted to uint.
func (sas *SAS) clear() uint {
	result := C.olm_clear_sas(sas.int)
	return uint(result)
}

// lastError fetches libolm's last-error string for this SAS object and maps
// it to a Go error through convertError.
func (sas *SAS) lastError() error {
	msg := C.GoString(C.olm_sas_last_error(sas.int))
	return convertError(msg)
}

// pubkeyLength reports, via olm_sas_pubkey_length, how many bytes the
// encoded public key of this SAS object occupies.
func (sas *SAS) pubkeyLength() uint {
	length := C.olm_sas_pubkey_length(sas.int)
	return uint(length)
}

// GetPubkey returns this SAS object's public key, as written by
// olm_sas_get_pubkey into a buffer of pubkeyLength() bytes.
//
// It panics with the libolm error if olm_sas_get_pubkey reports failure.
func (sas *SAS) GetPubkey() []byte {
	buf := make([]byte, sas.pubkeyLength())
	status := C.olm_sas_get_pubkey(sas.int, unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
	if status == errorVal() {
		panic(sas.lastError())
	}
	return buf
}

// SetTheirKey sets the public key of the other user.
//
// theirKey is copied before it is handed to libolm, so the caller's slice is
// never passed across the cgo boundary. NOTE(review): presumably this is
// because olm_sas_set_their_key may modify the buffer (e.g. decode it in
// place); confirm against libolm.
//
// An empty theirKey no longer panics with an index-out-of-range error. It is
// passed to libolm with length 0, and libolm is expected to reject it through
// its normal error path, which is returned here (assumed: libolm validates
// the key length — confirm).
func (sas *SAS) SetTheirKey(theirKey []byte) error {
	// The extra byte keeps &theirKeyCopy[0] valid when theirKey is empty. Only
	// len(theirKey) bytes are reported to libolm.
	theirKeyCopy := make([]byte, len(theirKey)+1)
	copy(theirKeyCopy, theirKey)
	r := C.olm_sas_set_their_key(
		sas.int,
		unsafe.Pointer(&theirKeyCopy[0]),
		C.size_t(len(theirKey)))
	if r == errorVal() {
		return sas.lastError()
	}
	return nil
}

// GenerateBytes generates count bytes to use for the short authentication
// string, derived by libolm (olm_sas_generate_bytes) using info.
//
// An empty info or a count of 0 no longer panics with an index-out-of-range
// error. Both are passed to libolm with their real lengths, and any libolm
// failure is returned as an error with a nil slice.
func (sas *SAS) GenerateBytes(info []byte, count uint) ([]byte, error) {
	// Both buffers get one spare byte so &buf[0] is always addressable, even
	// for zero-length data. libolm is only told the real lengths.
	infoCopy := make([]byte, len(info)+1)
	copy(infoCopy, info)
	output := make([]byte, count+1)
	r := C.olm_sas_generate_bytes(
		sas.int,
		unsafe.Pointer(&infoCopy[0]),
		C.size_t(len(info)),
		unsafe.Pointer(&output[0]),
		C.size_t(count))
	if r == errorVal() {
		return nil, sas.lastError()
	}
	// Hide the spare byte from callers: length and capacity are exactly count.
	return output[:count:count], nil
}

// macLength reports, via olm_sas_mac_length, how many bytes a MAC produced by
// olm_sas_calculate_mac occupies.
func (sas *SAS) macLength() uint {
	length := C.olm_sas_mac_length(sas.int)
	return uint(length)
}

// CalculateMAC generates a message authentication code (MAC) of input based on
// the shared secret, using info (via olm_sas_calculate_mac). The returned
// slice is macLength() bytes long.
//
// An empty input or info no longer panics with an index-out-of-range error.
// Both are passed to libolm with their real lengths, and any libolm failure is
// returned as an error with a nil slice.
func (sas *SAS) CalculateMAC(input []byte, info []byte) ([]byte, error) {
	// Spare trailing byte keeps &buf[0] addressable for empty data. libolm is
	// only told the real lengths.
	inputCopy := make([]byte, len(input)+1)
	copy(inputCopy, input)
	infoCopy := make([]byte, len(info)+1)
	copy(infoCopy, info)
	mac := make([]byte, sas.macLength())
	r := C.olm_sas_calculate_mac(
		sas.int,
		unsafe.Pointer(&inputCopy[0]),
		C.size_t(len(input)),
		unsafe.Pointer(&infoCopy[0]),
		C.size_t(len(info)),
		unsafe.Pointer(&mac[0]),
		C.size_t(len(mac)))
	if r == errorVal() {
		return nil, sas.lastError()
	}
	return mac, nil
}
