package structr

import (
	"fmt"
	"reflect"
	"slices"
	"strconv"
	"testing"
	"unsafe"

	"codeberg.org/gruf/go-mangler"
	"github.com/modern-go/reflect2"
)

// test bundles what is needed to exercise a cache over values of type T:
// the index configuration to initialize the cache with, a copy function,
// the sample values to store, and an equality check for values.
type test[T any] struct {
	// indices is the cache initialization (index) config.
	indices []IndexConfig

	// copyfn returns a copy of the given value.
	copyfn func(*T) *T

	// values are the sample values to cache.
	values []*T

	// equalfn reports whether two values are equal.
	equalfn func(*T, *T) bool
}

type structA struct {
	Field1 string
	Field2 int
	Field3 float32
}

// String renders the value with field names (as the %+v verb does) and
// escapes the result the way strconv.Quote does, minus the enclosing
// double quotes. A nil receiver renders as "<nil>".
func (a *structA) String() string {
	if a != nil {
		quoted := strconv.Quote(fmt.Sprintf("%+v", *a))
		return quoted[1 : len(quoted)-1]
	}
	return "<nil>"
}

// structAIndices is the index configuration for structA tests: Field1 with
// Multiple and AllowZero disabled, Field2 and Field3 with both enabled, and
// a composite Field2+Field3 index with only AllowZero enabled.
var structAIndices = []IndexConfig{
	{Fields: "Field1", Multiple: false, AllowZero: false},
	{Fields: "Field2", Multiple: true, AllowZero: true},
	{Fields: "Field3", Multiple: true, AllowZero: true},
	{Fields: "Field2,Field3", Multiple: false, AllowZero: true},
}

// structAValues are the sample structA values. Each Field1 label spells
// out "<Field2>-<Field3>", with "zero" / "one" standing in for 0 and 1.
var structAValues = []*structA{
	{
		Field1: "zero-zero",
		Field2: 0,
		Field3: 0,
	},
	{
		Field1: "one-zero",
		Field2: 1,
		Field3: 0,
	},
	{
		Field1: "zero-one",
		Field2: 0,
		Field3: 1,
	},
	{
		Field1: "69-420",
		Field2: 69,
		Field3: 420,
	},
	{
		Field1: "420-69",
		Field2: 420,
		Field3: 69,
	},
}

// testStructA is the test case for structA values.
var testStructA = test[structA]{
	indices: structAIndices,
	values:  structAValues,
	copyfn: func(in *structA) *structA {
		// structA holds only value fields, so copying the struct suffices.
		dup := *in
		return &dup
	},
	equalfn: func(a, b *structA) bool {
		// Every structA field is comparable, so compare the structs whole;
		// this compares Field1, Field2 and Field3 in order.
		return *a == *b
	},
}

type structB struct {
	Field1 *string
	Field2 *int
	Field3 *float64
}

// String renders the value with field names (as the %+v verb does) and
// escapes the result the way strconv.Quote does, minus the enclosing
// double quotes. A nil receiver renders as "<nil>".
func (b *structB) String() string {
	if b != nil {
		quoted := strconv.Quote(fmt.Sprintf("%+v", *b))
		return quoted[1 : len(quoted)-1]
	}
	return "<nil>"
}

// structBIndices is the index configuration for structB tests: Field1 with
// Multiple and AllowZero disabled, Field2 and Field3 with both enabled, and
// a composite Field2+Field3 index with only AllowZero enabled.
var structBIndices = []IndexConfig{
	{Fields: "Field1", Multiple: false, AllowZero: false},
	{Fields: "Field2", Multiple: true, AllowZero: true},
	{Fields: "Field3", Multiple: true, AllowZero: true},
	{Fields: "Field2,Field3", Multiple: false, AllowZero: true},
}

// structBValues are the sample structB values. Each Field1 label spells
// out "<Field2>-<Field3>", with "nil" / "one" standing in for a nil
// pointer and a pointer to 1.
var structBValues = []*structB{
	{
		Field1: ptr("nil-nil"),
		Field2: nil,
		Field3: nil,
	},
	{
		Field1: ptr("one-nil"),
		Field2: ptr(1),
		Field3: nil,
	},
	{
		Field1: ptr("nil-one"),
		Field2: nil,
		Field3: ptr(1.0),
	},
	{
		Field1: ptr("69-420"),
		Field2: ptr(69),
		Field3: ptr(420.0),
	},
	{
		Field1: ptr("420-69"),
		Field2: ptr(420),
		Field3: ptr(69.0),
	},
}

// testStructB is the test case for structB values.
var testStructB = test[structB]{
	indices: structBIndices,
	values:  structBValues,
	copyfn: func(in *structB) *structB {
		// Shallow copy: the result shares pointer targets with in.
		dup := *in
		return &dup
	},
	equalfn: func(a, b *structB) bool {
		// Compare pointed-to values (not addresses), stopping at the
		// first mismatching field.
		switch {
		case !ptrsequal(a.Field1, b.Field1):
			return false
		case !ptrsequal(a.Field2, b.Field2):
			return false
		default:
			return ptrsequal(a.Field3, b.Field3)
		}
	},
}

// channel and function are named chan / func types used as structC field
// types; init registers custom mangler functions for them.
type (
	channel  chan struct{}
	function func()
)

func init() {
	ch := make(channel)
	fn := function(func() {})

	rt := reflect.TypeOf(ch)
	mangler.Register(rt, func(buf []byte, ptr unsafe.Pointer) []byte {
		var ch channel
		if ptr := (*channel)(ptr); ptr != nil {
			ch = *ptr
		}
		return fmt.Appendf(buf, "%#v", ch)
	})

	rt = reflect.TypeOf(&ch)
	mangler.Register(rt, func(buf []byte, ptr unsafe.Pointer) []byte {
		var ch *channel
		if ptr := (**channel)(ptr); ptr != nil {
			ch = *ptr
		}
		return fmt.Appendf(buf, "%#v", ch)
	})

	rt = reflect.TypeOf(fn)
	mangler.Register(rt, func(buf []byte, ptr unsafe.Pointer) []byte {
		var fn function
		if ptr := (*function)(ptr); ptr != nil {
			fn = *ptr
		}
		return fmt.Appendf(buf, "%#v", fn)
	})

	rt = reflect.TypeOf(&fn)
	mangler.Register(rt, func(buf []byte, ptr unsafe.Pointer) []byte {
		var fn *function
		if ptr := (**function)(ptr); ptr != nil {
			fn = *ptr
		}
		return fmt.Appendf(buf, "%#v", fn)
	})
}

// structC is a test struct covering a wider range of field kinds:
//   - Field1: plain string; Field2: pointer to int.
//   - Field3: pointer to an anonymous struct, indexed through nested field
//     paths ("Field3.Field1", "Field3.Field2") in structCIndices.
//   - Field4-Field7: the channel / function named types and pointers to
//     them, for which init registers custom mangler functions.
//   - Field8: string slice; Field9 / Field10: string arrays of length 2 / 1.
//   - Field11 / Field12: an anonymous struct value and a pointer to one.
//   - Field13: struct holding a single unexported pointer field.
//   - Field14: single-element array of string pointers.
//
// The anonymous struct types here are repeated verbatim in composite
// literals and in testStructC.copyfn, so they must stay in sync.
type structC struct {
	Field1 string
	Field2 *int
	Field3 *struct {
		Field1 string
		Field2 *int
	}
	Field4  channel
	Field5  *channel
	Field6  function
	Field7  *function
	Field8  []string
	Field9  [2]string
	Field10 [1]string
	Field11 struct {
		Field1 *string
		Field2 string
	}
	Field12 *struct {
		Field1 *string
		Field2 string
	}
	Field13 struct{ onlyone *string }
	Field14 [1]*string
}

// String renders the value with field names (as the %+v verb does) and
// escapes the result the way strconv.Quote does, minus the enclosing
// double quotes. A nil receiver renders as "<nil>".
func (c *structC) String() string {
	if c != nil {
		quoted := strconv.Quote(fmt.Sprintf("%+v", *c))
		return quoted[1 : len(quoted)-1]
	}
	return "<nil>"
}

// structCIndices is the index configuration for structC tests. Beyond the
// plain Field1 / Field2 indices it covers nested field paths into the
// Field3 pointer-to-struct, composite keys over the channel / function
// fields (Field4-Field7), and slice, array and struct-valued fields.
var structCIndices = []IndexConfig{
	{
		Fields:    "Field1",
		Multiple:  false,
		AllowZero: false,
	},
	{
		Fields:    "Field2",
		Multiple:  true,
		AllowZero: true,
	},
	// Nested field paths through the Field3 pointer.
	{
		Fields:    "Field3.Field1",
		Multiple:  true,
		AllowZero: true,
	},
	{
		Fields:    "Field3.Field2",
		Multiple:  true,
		AllowZero: true,
	},
	// Composite keys mixing the channel / function fields (and pointers to
	// them), which rely on the manglers registered in init.
	{
		Fields:    "Field4,Field5",
		Multiple:  true,
		AllowZero: false,
	},
	{
		Fields:    "Field6,Field7",
		Multiple:  true,
		AllowZero: false,
	},
	{
		Fields:    "Field4,Field7",
		Multiple:  true,
		AllowZero: false,
	},
	{
		Fields:    "Field5,Field6",
		Multiple:  true,
		AllowZero: false,
	},
	// Slice, array and struct-valued fields.
	{
		Fields:    "Field8",
		Multiple:  true,
		AllowZero: false,
	},
	{
		Fields:    "Field9,Field10",
		Multiple:  true,
		AllowZero: false,
	},
	{
		Fields:    "Field11,Field12",
		Multiple:  true,
		AllowZero: false,
	},
	// Single-pointer containers: a one-field struct and a one-element array.
	{
		Fields:    "Field13",
		Multiple:  true,
		AllowZero: false,
	},
	{
		Fields:    "Field14",
		Multiple:  true,
		AllowZero: false,
	},
}

// structCValues are the sample structC values. Every entry has a distinct
// Field1 label; the later entries' labels name the field(s) they populate.
var structCValues = []*structC{
	// Entries with Field3 set: Field3.Field1 / Field3.Field2 repeat the
	// values of Field1 / Field2.
	{Field1: "zero-zero", Field2: nil, Field3: &struct {
		Field1 string
		Field2 *int
	}{
		Field1: "zero-zero",
		Field2: nil,
	}},
	{Field1: "one-zero", Field2: ptr(1), Field3: &struct {
		Field1 string
		Field2 *int
	}{
		Field1: "one-zero",
		Field2: ptr(1),
	}},
	{Field1: "zero-one", Field2: nil, Field3: &struct {
		Field1 string
		Field2 *int
	}{
		Field1: "zero-one",
		Field2: nil,
	}},
	{Field1: "69-420", Field2: ptr(69), Field3: &struct {
		Field1 string
		Field2 *int
	}{
		Field1: "69-420",
		Field2: ptr(69),
	}},
	{Field1: "420-69", Field2: ptr(420), Field3: &struct {
		Field1 string
		Field2 *int
	}{
		Field1: "420-69",
		Field2: ptr(420),
	}},
	// Neither Field2 nor Field3 set.
	{Field1: "empty-nil", Field2: nil, Field3: nil},
	// Entries exercising the channel / function fields (Field4-Field7).
	{
		Field1: "channel-set",
		Field4: make(channel),
		Field5: func() *channel { ch := make(channel); return &ch }(),
	},
	{
		Field1: "function-set",
		Field6: func() {},
		Field7: func() *function { fn := function(func() {}); return &fn }(),
	},
	{
		Field1: "channel+function-set",
		Field4: make(channel),
		Field5: func() *channel { ch := make(channel); return &ch }(),
		Field6: func() {},
		Field7: func() *function { fn := function(func() {}); return &fn }(),
	},
	// Entries exercising the slice and array fields.
	{
		Field1: "[]string-set",
		Field8: []string{"hello", "world"},
	},
	{
		Field1: "[2]string-set",
		Field9: [2]string{"hello", "world"},
	},
	{
		Field1:  "[1]string-set",
		Field10: [1]string{"hello world"},
	},
	// NOTE(review): this entry also sets Field10 to the same value as the
	// "[1]string-set" entry above — confirm that is intentional.
	{
		Field1:  "field11-set",
		Field10: [1]string{"hello world"},
		Field11: struct {
			Field1 *string
			Field2 string
		}{
			Field1: new(string),
			Field2: "world",
		},
	},
	{
		Field1: "field12-set",
		Field12: &struct {
			Field1 *string
			Field2 string
		}{
			Field1: new(string),
			Field2: "world",
		},
	},
	// Field13 / Field14 entries: a pointer to an empty string (new(string))
	// versus a pointer to a non-empty string.
	{
		Field1:  "field13-set",
		Field13: struct{ onlyone *string }{new(string)},
	},
	{
		Field1:  "field13-with-value",
		Field13: struct{ onlyone *string }{ptr("hello")},
	},
	{
		Field1:  "field14-set",
		Field14: [1]*string{new(string)},
	},
	{
		Field1:  "field14-with-value",
		Field14: [1]*string{ptr("hello")},
	},
}

// testStructC is the test case for structC values.
//
// copyfn returns a shallow copy of the value, except that the pointed-to
// Field3 and Field12 structs and the single string pointers held in Field13
// and Field14 are duplicated, so the copy does not share them with the input.
//
// equalfn compares every field except the function-typed Field6 / Field7
// (Go func values are not comparable). Pointer fields are compared by their
// pointed-to values via ptrsequal; channels are compared by identity.
var testStructC = test[structC]{
	indices: structCIndices,
	copyfn: func(in *structC) *structC {
		out := new(structC)
		*out = *in
		if in.Field3 != nil {
			out.Field3 = new(struct {
				Field1 string
				Field2 *int
			})
			*out.Field3 = *in.Field3
		}
		// Field12 has the same pointer-to-struct shape as Field3.
		if in.Field12 != nil {
			out.Field12 = new(struct {
				Field1 *string
				Field2 string
			})
			*out.Field12 = *in.Field12
		}
		if in.Field13.onlyone != nil {
			out.Field13.onlyone = new(string)
			*out.Field13.onlyone = *in.Field13.onlyone
		}
		// Field14 holds a single pointer, just like Field13.
		if in.Field14[0] != nil {
			out.Field14[0] = new(string)
			*out.Field14[0] = *in.Field14[0]
		}
		return out
	},
	values: structCValues,
	equalfn: func(a, b *structC) bool {
		if a.Field1 != b.Field1 {
			return false
		}
		if !ptrsequal(a.Field2, b.Field2) {
			return false
		}
		if a.Field3 == nil && b.Field3 != nil {
			return false
		} else if a.Field3 != nil && b.Field3 == nil {
			return false
		} else if a.Field3 != nil && b.Field3 != nil {
			if a.Field3.Field1 != b.Field3.Field1 {
				return false
			}
			if !ptrsequal(a.Field3.Field2, b.Field3.Field2) {
				return false
			}
		}
		if a.Field4 != b.Field4 {
			return false
		}
		if !ptrsequal(a.Field5, b.Field5) {
			return false
		}
		if !slices.Equal(a.Field8, b.Field8) {
			return false
		}
		if !slices.Equal(a.Field9[:], b.Field9[:]) {
			return false
		}
		if a.Field10[0] != b.Field10[0] {
			return false
		}
		if !ptrsequal(a.Field11.Field1, b.Field11.Field1) {
			return false
		}
		if a.Field11.Field2 != b.Field11.Field2 {
			return false
		}
		if a.Field12 == nil && b.Field12 != nil {
			return false
		} else if a.Field12 != nil && b.Field12 == nil {
			return false
		} else if a.Field12 != nil && b.Field12 != nil {
			if !ptrsequal(a.Field12.Field1, b.Field12.Field1) {
				return false
			}
			if a.Field12.Field2 != b.Field12.Field2 {
				return false
			}
		}
		if !ptrsequal(a.Field13.onlyone, b.Field13.onlyone) {
			return false
		}
		// Field14 was previously never compared, so values differing only
		// in Field14 were wrongly reported as equal.
		return ptrsequal(a.Field14[0], b.Field14[0])
	},
}

// String returns the serialized cache key as a double-quoted Go string
// literal (the same bytes strconv.Quote produces), which is handy for
// debug / test output.
func (k Key) String() string {
	quoted := strconv.AppendQuote(nil, k.key)
	return string(quoted)
}

// indexkey extracts the index key interface parts for index from value.
// The field walk starts from the address of the value pointer itself
// (i.e. a **T), and the boolean reports whether any parts slice was returned.
func indexkey[T any](index *Index, value *T) ([]any, bool) {
	base := unsafe.Pointer(&value)
	parts := extract_fields_ifaces(base, index.fields)
	if parts == nil {
		return nil, false
	}
	return parts, true
}

// extract_fields_ifaces walks each field's chain of offsets starting from
// ptr and returns the field values boxed as empty interfaces, one per field
// and in field order. The returned slice is always non-nil.
//
// For each offset, the pointer is first dereferenced offset.derefs times
// (via deref), then advanced by offset.offset bytes. If a nil pointer is hit
// along the way, the walk stops and field.zero is used as the value pointer.
// Every element holds a value of the field's type itself (not a pointer to
// it), including when no value pointer is available at all.
func extract_fields_ifaces(ptr unsafe.Pointer, fields []struct_field) []any {
	ifaces := make([]any, len(fields))
	for i, field := range fields {
		// Each field's walk starts at the base pointer.
		fptr := ptr

		for _, offset := range field.offsets {
			// Dereference any ptrs to offset.
			fptr = deref(fptr, offset.derefs)

			if fptr == nil {
				// Nil along the path: use the field's zero value.
				fptr = field.zero
				break
			}

			// Jump forward by offset to next ptr.
			fptr = unsafe.Pointer(uintptr(fptr) + offset.offset)
		}

		if fptr == nil {
			// No value pointer at all (ptr was nil with no offsets to
			// walk, or field.zero is nil). Box the zero value of the field
			// type itself; reflect2's Type.New() would instead box a *T,
			// inconsistent with the UnsafeIndirect branch below.
			ifaces[i] = reflect.Zero(field.rtype).Interface()
			continue
		}

		// Repack value data ptr as empty interface of the field type.
		ifaces[i] = reflect2.Type2(field.rtype).UnsafeIndirect(fptr)
	}

	return ifaces
}

// ptr returns a pointer to a fresh copy of t.
func ptr[T any](t T) *T {
	p := new(T)
	*p = t
	return p
}

// ptrsequal reports whether p1 and p2 are both nil, or are both non-nil
// and point to equal values.
func ptrsequal[T comparable](p1, p2 *T) bool {
	if p1 == nil || p2 == nil {
		// Equal only when both are nil.
		return p1 == p2
	}
	return *p1 == *p2
}

// catchpanic runs do and fails the test unless the value recovered from it
// equals expect. recover yields nil when do returns normally, so a nil
// expect asserts that do does not panic.
//
// Values are compared with ==. If that comparison itself panics (both values
// share an uncomparable dynamic type, e.g. a slice), reflect.DeepEqual is
// used instead of crashing inside the deferred handler.
func catchpanic(t *testing.T, do func(), expect any) {
	t.Helper()

	// equal compares with ==, falling back to DeepEqual for uncomparables.
	equal := func(a, b any) (eq bool) {
		defer func() {
			if recover() != nil {
				eq = reflect.DeepEqual(a, b)
			}
		}()
		return a == b
	}

	defer func() {
		t.Helper()
		r := recover()
		if !equal(r, expect) {
			t.Fatalf("expected panic with %v, got %v", expect, r)
		}
	}()
	do()
}
