// Copyright 2020 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package checkers

import (
	"fmt"
	"regexp"

	gc "gopkg.in/check.v1"
)

// MultiChecker is a deep checker that by default matches for equality.
// Checks can be overridden based on path (either an explicit match or a regexp).
//
// The embedded CheckerInfo carries the checker's name and parameter names.
// checks holds the overrides keyed by exact path; regexChecks holds the
// overrides selected by matching the path against a regular expression, in
// the order they were added.
type MultiChecker struct {
	*gc.CheckerInfo
	checks      map[string]multiCheck
	regexChecks []regexCheck
}

// multiCheck pairs a checker with the arguments that MultiChecker.Check passes
// to it after the obtained value. Any argument equal to ExpectedValue is
// replaced with the expected value before the checker is invoked.
type multiCheck struct {
	checker gc.Checker
	args    []interface{}
}

// regexCheck is a multiCheck that is selected when regex matches the path of
// the value being compared. The regex is compiled by MultiChecker.AddRegex
// from the pattern supplied by the caller.
type regexCheck struct {
	multiCheck
	regex *regexp.Regexp
}

// NewMultiChecker returns a MultiChecker with no overrides registered.
// The result is a deep checker that compares for equality by default; use Add
// or AddRegex to override the check used for particular paths (selected by an
// explicit match or by a regexp respectively).
func NewMultiChecker() *MultiChecker {
	info := &gc.CheckerInfo{
		Name:   "MultiChecker",
		Params: []string{"obtained", "expected"},
	}
	mc := &MultiChecker{CheckerInfo: info}
	// The map must be allocated up front because Add writes to it directly.
	mc.checks = map[string]multiCheck{}
	return mc
}

// Add registers an explicit checker for the exact path given, replacing any
// checker previously added for that same path. The args are handed to c after
// the obtained value; use ExpectedValue as a placeholder for the expected
// value. The receiver is returned so that calls can be chained.
func (checker *MultiChecker) Add(path string, c gc.Checker, args ...interface{}) *MultiChecker {
	entry := multiCheck{checker: c, args: args}
	checker.checks[path] = entry
	return checker
}

// AddRegex registers a checker for every path that is matched in full by
// pathRegex. Regex checkers are only consulted when no explicit checker was
// added for a path, and the first one added that matches wins.
//
// The args are handed to c after the obtained value; use ExpectedValue as a
// placeholder for the expected value. The receiver is returned so that calls
// can be chained.
//
// AddRegex panics if pathRegex is not a valid regular expression.
func (checker *MultiChecker) AddRegex(pathRegex string, c gc.Checker, args ...interface{}) *MultiChecker {
	// Wrap the pattern in a non-capturing group before anchoring it, so that a
	// top-level alternation such as "a|b" is anchored as a whole instead of
	// being parsed as "^a" or "b$".
	anchored := regexp.MustCompile("^(?:" + pathRegex + ")$")
	checker.regexChecks = append(checker.regexChecks, regexCheck{
		multiCheck: multiCheck{
			checker: c,
			args:    args,
		},
		regex: anchored,
	})
	return checker
}

// Check implements the gocheck Checker interface. It compares params[0]
// (obtained) with params[1] (expected) using DeepEqualWithCustomCheck,
// supplying a callback that applies any checker registered for the path being
// compared.
//
// For each path, a checker registered with Add for that exact path takes
// precedence; otherwise the first checker registered with AddRegex whose
// regex matches the path is used. If neither exists, the callback reports
// useDefault so that the default comparison is applied.
//
// It returns true and an empty string on success, or false and a message
// describing the first mismatch on failure.
func (checker *MultiChecker) Check(params []interface{}, names []string) (result bool, errStr string) {
	customCheckFunc := func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error) {
		var mc *multiCheck
		if c, ok := checker.checks[path]; ok {
			mc = &c
		} else {
			// Index into the slice rather than ranging by value so that each
			// regexCheck is not copied just to test its regex.
			for i := range checker.regexChecks {
				if rc := &checker.regexChecks[i]; rc.regex.MatchString(path) {
					mc = &rc.multiCheck
					break
				}
			}
		}
		if mc == nil {
			// No override is registered for this path.
			return true, false, nil
		}

		// The obtained value is always the first parameter, followed by the
		// arguments registered alongside the checker. This builds a fresh
		// slice, so the substitution below never modifies mc.args.
		checkParams := append([]interface{}{a1}, mc.args...)
		info := mc.checker.Info()
		if len(checkParams) < len(info.Params) {
			// checkParams already includes the obtained value, so its length
			// is directly comparable with len(info.Params).
			return false, false, fmt.Errorf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(info.Params), len(checkParams))
		}
		// Copy since it may be mutated by Check.
		checkNames := append([]string{}, info.Params...)

		// Trim to the expected params len; surplus arguments are dropped.
		checkParams = checkParams[:len(info.Params)]

		// Substitute the expected value for every ExpectedValue placeholder.
		for i, v := range checkParams {
			if v == ExpectedValue {
				checkParams[i] = a2
			}
		}

		ok, msg := mc.checker.Check(checkParams, checkNames)
		if ok {
			return false, true, nil
		}
		if path == "" {
			path = "top level"
		}
		return false, false, fmt.Errorf("mismatch at %s: %s", path, msg)
	}
	if ok, err := DeepEqualWithCustomCheck(params[0], params[1], customCheckFunc); !ok {
		return false, err.Error()
	}
	return true, ""
}

// ExpectedValue is a placeholder argument: when passed to MultiChecker.Add or
// MultiChecker.AddRegex, it is substituted with the expected value before the
// registered checker runs. It is recognised by comparing arguments against
// this pointer with ==.
//
// NOTE(review): the Go spec does not guarantee that pointers to distinct
// zero-size variables differ, so another *struct{} argument could conceivably
// compare equal to ExpectedValue - confirm that this is acceptable.
var ExpectedValue = new(struct{})