#include "stdafx.h"
#include "NameSet.h"
#include "Engine.h"
#include "Exception.h"
#include "Core/StrBuf.h"
#include "Core/Str.h"
#include <limits>

namespace storm {

	// Start with an empty item list. The template array is only allocated once a template is added.
	NameOverloads::NameOverloads()
		: items(new (this) Array<Named *>()), templates(null) {}

	// True when there are neither items nor templates.
	Bool NameOverloads::empty() const {
		if (!items->empty())
			return false;
		return !templates || templates->empty();
	}

	// Number of (non-template) items.
	Nat NameOverloads::count() const {
		return items->count();
	}

	// Access item number 'id'. Same as 'at'.
	Named *NameOverloads::operator [](Nat id) const {
		return items->at(id);
	}

	// Access item number 'id'. Same as 'operator []'.
	Named *NameOverloads::at(Nat id) const {
		return items->at(id);
	}

	// Element-wise comparison of two parameter lists.
	static bool equals(Array<Value> *a, Array<Value> *b) {
		Nat n = a->count();
		if (n != b->count())
			return false;

		for (Nat k = 0; k < n; k++) {
			if (a->at(k) != b->at(k))
				return false;
		}
		return true;
	}

	// Element-wise comparison of two parameter lists, where each element is first passed through
	// 'ctx->normalize'. A null 'ctx' means a plain comparison.
	static bool equals(Array<Value> *a, Array<Value> *b, ReplaceContext *ctx) {
		if (!ctx)
			return equals(a, b);

		Nat n = a->count();
		if (n != b->count())
			return false;

		for (Nat k = 0; k < n; k++) {
			if (ctx->normalize(a->at(k)) != ctx->normalize(b->at(k)))
				return false;
		}
		return true;
	}

	// Return the first item whose parameter list equals that of 'item', or null if there is none.
	MAYBE(Named *) NameOverloads::has(Named *item) {
		Nat n = items->count();
		for (Nat k = 0; k < n; k++) {
			Named *candidate = items->at(k);
			if (storm::equals(item->params, candidate->params))
				return candidate;
		}
		return null;
	}

	// Add 'item'.
	// - Adding the very same object a second time does nothing.
	// - If an existing item has the same parameters and carries 'namedAllowReplace', it is
	//   replaced by 'item' and 'item->automaticReplace' is called with the old item.
	// - Otherwise a same-parameter conflict raises a TypedefError.
	void NameOverloads::add(Named *item) {
		for (Nat k = 0; k < items->count(); k++) {
			Named *existing = items->at(k);

			// Identity check. Re-inserting the same object can happen when a template "helps" the
			// implementation find a suitable overload (e.g. by modifying a parameter) and hands
			// back an object it created earlier. Since only identical objects are accepted, this
			// only excuses cases where it is intentional.
			if (existing == item)
				return;

			if (!storm::equals(item->params, existing->params))
				continue;

			// Same signature: only acceptable if the existing item allows being replaced.
			if (existing->flags & namedAllowReplace) {
				items->at(k) = item;
				item->automaticReplace(existing);
				return;
			}

			throw new (this) TypedefError(
				item->pos,
				TO_S(engine(), item << S(" is already defined at:\n@") << existing->pos << S(": here")));
		}

		items->push(item);
	}

	// Add a template. Templates cannot be checked for conflicts here, since we do not know what
	// they will generate. The template array is created on first use.
	void NameOverloads::add(Template *item) {
		if (!templates) {
			templates = new (this) Array<Template *>();
		}
		templates->push(item);
	}

	// Remove the first occurrence of 'item' (compared by identity). Returns whether it was found.
	Bool NameOverloads::remove(Named *item) {
		for (Nat k = 0; k < items->count(); k++) {
			if (items->at(k) != item)
				continue;

			items->remove(k);
			return true;
		}
		return false;
	}

	// Remove the first occurrence of the template 'item' (compared by identity). Returns whether it
	// was found.
	Bool NameOverloads::remove(Template *item) {
		if (templates) {
			for (Nat k = 0; k < templates->count(); k++) {
				if (templates->at(k) != item)
					continue;

				templates->remove(k);
				return true;
			}
		}
		return false;
	}

	// Append all items and templates from 'from'. Every incoming item is checked against the
	// existing items before anything is modified. A parameter conflict therefore throws a
	// TypedefError and leaves this object unchanged.
	void NameOverloads::merge(NameOverloads *from) {
		Array<Named *> *incoming = from->items;

		// Validation pass.
		for (Nat k = 0; k < incoming->count(); k++) {
			Named *candidate = incoming->at(k);

			for (Nat m = 0; m < items->count(); m++) {
				Named *existing = items->at(m);
				if (!storm::equals(candidate->params, existing->params))
					continue;

				throw new (this) TypedefError(
					candidate->pos,
					TO_S(engine(), candidate << S(" is already defined at:\n@") << existing->pos << S(": here")));
			}
		}

		// Templates cannot be validated, so just take them over.
		if (from->templates && from->templates->any()) {
			if (!templates)
				templates = new (this) Array<Template *>();

			Array<Template *> *src = from->templates;
			for (Nat k = 0; k < src->count(); k++)
				templates->push(src->at(k));
		}

		for (Nat k = 0; k < incoming->count(); k++)
			items->push(incoming->at(k));
	}

	// Compare the items here with the items in 'with'. Parameter lists are compared through
	// 'ctx' when it is non-null.
	// - Each item here is paired with the first not-yet-paired item in 'with' that has equal
	//   parameters, and the pair is reported through 'callback.changed'.
	// - Items here without a partner are reported through 'callback.removed'.
	// - Items in 'with' that were never paired are reported through 'callback.added'.
	// Templates are not considered.
	void NameOverloads::diff(NameOverloads *with, NameDiff &callback, ReplaceContext *ctx) {
		Array<Bool> *paired = new (this) Array<Bool>(with->items->count(), false);

		for (Nat k = 0; k < items->count(); k++) {
			Named *here = items->at(k);
			Named *partner = null;

			for (Nat m = 0; !partner && m < with->items->count(); m++) {
				if (paired->at(m))
					continue;

				Named *other = with->items->at(m);
				if (storm::equals(here->params, other->params, ctx)) {
					paired->at(m) = true;
					partner = other;
				}
			}

			if (partner)
				callback.changed(here, partner);
			else
				callback.removed(here);
		}

		for (Nat m = 0; m < with->items->count(); m++) {
			if (paired->at(m))
				continue;
			callback.added(with->items->at(m));
		}
	}

	// Report every item as added.
	void NameOverloads::diffAdded(NameDiff &callback) {
		for (Nat k = 0; k < items->count(); k++) {
			callback.added(items->at(k));
		}
	}

	// Report every item as removed.
	void NameOverloads::diffRemoved(NameDiff &callback) {
		for (Nat k = 0; k < items->count(); k++) {
			callback.removed(items->at(k));
		}
	}

	// Report every template (if any) as added.
	void NameOverloads::diffTemplatesAdded(NameDiff &callback) {
		if (templates) {
			for (Nat k = 0; k < templates->count(); k++)
				callback.added(templates->at(k));
		}
	}

	// Report every template (if any) as removed.
	void NameOverloads::diffTemplatesRemoved(NameDiff &callback) {
		if (templates) {
			for (Nat k = 0; k < templates->count(); k++)
				callback.removed(templates->at(k));
		}
	}

	// True if at least one template is registered.
	Bool NameOverloads::anyTemplates() const {
		if (!templates)
			return false;
		return templates->any();
	}

	// Ask every template to generate an entity for 'part'.
	// - Generated entities that 'part' does not match (matches() < 0 for 'source') are ignored.
	// - Returns the single matching generated entity, or null if none matched (or if there are
	//   no templates at all).
	// - Throws a TypedefError (positioned at 'owner') if more than one template produced a
	//   matching entity.
	Named *NameOverloads::createTemplate(NameSet *owner, SimplePart *part, Scope source) {
		if (!templates)
			return null;

		Named *found = null;
		for (Nat i = 0; i < templates->count(); i++) {
			Named *n = templates->at(i)->generate(part);
			if (!n)
				continue;

			// Only consider it if it matches. Checking this *before* the ambiguity test makes the
			// result independent of template order. Previously a non-matching entity generated
			// after a match raised an error, while one generated before a match was ignored.
			if (part->matches(n, source) < 0)
				continue;

			if (found)
				throw new (this) TypedefError(owner->pos, TO_S(engine(), S("Multiple template matches for: ") << part));

			found = n;
		}

		return found;
	}

	// Output one line per item, followed by a line with the number of templates (if any).
	void NameOverloads::toS(StrBuf *to) const {
		for (Nat k = 0; k < items->count(); k++) {
			*to << items->at(k);
			*to << L"\n";
		}

		if (anyTemplates())
			*to << L"<" << templates->count() << L" templates>\n";
	}


	/**
	 * NameSet.
	 */

	// All constructors forward to the corresponding Named constructor and then run 'init'.
	NameSet::NameSet(Str *name)
		: Named(name) {
		init();
	}

	NameSet::NameSet(Str *name, Array<Value> *params)
		: Named(name, params) {
		init();
	}

	NameSet::NameSet(SrcPos pos, Str *name)
		: Named(pos, name) {
		init();
	}

	NameSet::NameSet(SrcPos pos, Str *name, Array<Value> *params)
		: Named(pos, name, params) {
		init();
	}

	// Shared constructor logic: reset the loading state and the anonymous-name counter.
	void NameSet::init() {
		nextAnon = 0;
		loaded = false;
		loading = false;
		sourceDiscarded = false;

		// Only run 'lateInit' here once the engine reports 'bootTemplates'. Before that,
		// 'overloads' stays unallocated until 'lateInit' or an 'add'/'merge' creates it.
		if (engine().has(bootTemplates))
			lateInit();
	}

	// Run the base class's late initialization, then allocate the overload map if needed.
	void NameSet::lateInit() {
		Named::lateInit();

		if (!overloads) {
			overloads = new (this) Map<Str *, NameOverloads *>();
		}
	}

	void NameSet::compile() {
		forceLoad();

		for (Iter i = begin(), e = end(); i != e; ++i)
			i.v()->compile();
	}

	void NameSet::discardSource() {
		sourceDiscarded = true;
		for (Iter i = begin(), e = end(); i != e; ++i)
			i.v()->discardSource();
	}

	void NameSet::stopDiscardSource() {
		sourceDiscarded = false;
	}

	void NameSet::watchAdd(Named *notifyTo) {
		if (!notify)
			notify = new (this) WeakSet<Named>();
		notify->put(notifyTo);
	}

	void NameSet::watchRemove(Named *notifyTo) {
		if (!notify)
			return;
		notify->remove(notifyTo);
	}

	// Tell every registered watcher that 'what' was added to this set.
	void NameSet::notifyAdd(Named *what) {
		if (notify) {
			WeakSet<Named>::Iter it = notify->iter();
			for (Named *watcher = it.next(); watcher; watcher = it.next())
				watcher->notifyAdded(this, what);
		}
	}

	// Tell every registered watcher that 'what' was removed from this set.
	void NameSet::notifyRemove(Named *what) {
		if (notify) {
			WeakSet<Named>::Iter it = notify->iter();
			for (Named *watcher = it.next(); watcher; watcher = it.next())
				watcher->notifyRemoved(this, what);
		}
	}

	// Return an entity with the same name and parameters as 'item', or null if there is none.
	MAYBE(Named *) NameSet::has(Named *item) const {
		if (overloads) {
			if (NameOverloads *group = overloads->get(item->name, null))
				return group->has(item);
		}
		return null;
	}

	// Add an entity. 'at' is relied upon to create the overload group when the name is new.
	// Afterwards the item is made a child of this set, has its source discarded if that mode is
	// active, and watchers are notified.
	void NameSet::add(Named *item) {
		if (!overloads) {
			overloads = new (this) Map<Str *, NameOverloads *>();
		}

		NameOverloads *group = overloads->at(item->name);
		group->add(item);

		makeChild(item);
		if (sourceDiscarded)
			item->discardSource();
		notifyAdd(item);
	}

	// Add a template under its name. Watchers are not notified.
	void NameSet::add(Template *item) {
		if (!overloads) {
			overloads = new (this) Map<Str *, NameOverloads *>();
		}

		NameOverloads *group = overloads->at(item->name);
		group->add(item);
	}

	// Remove 'item' (compared by identity). Watchers are notified if it was found. If its overload
	// group becomes empty, the group is dropped. Returns whether 'item' was found.
	Bool NameSet::remove(Named *item) {
		if (!overloads)
			return false;

		// Use get() rather than at(). at() creates a group for unknown names ('add' relies on
		// that), which would be allocated here only to be removed again right away.
		NameOverloads *o = overloads->get(item->name, null);
		if (!o)
			return false;

		Bool ok = o->remove(item);
		if (ok)
			notifyRemove(item);
		if (o->empty())
			overloads->remove(item->name);
		return ok;
	}

	// Remove the template 'item' (compared by identity). If its overload group becomes empty, the
	// group is dropped. Watchers are not notified. Returns whether 'item' was found.
	Bool NameSet::remove(Template *item) {
		if (!overloads)
			return false;

		// Use get() rather than at(). at() creates a group for unknown names ('add' relies on
		// that), which would be allocated here only to be removed again right away.
		NameOverloads *o = overloads->get(item->name, null);
		if (!o)
			return false;

		Bool ok = o->remove(item);
		if (o->empty())
			overloads->remove(item->name);
		return ok;
	}

	// Produce a fresh name of the form "@ <n>", where n counts up for each call on this set.
	Str *NameSet::anonName() {
		StrBuf *out = new (this) StrBuf();
		*out << S("@ ");
		*out << (nextAnon++);
		return out->toS();
	}

	// Collect all non-template entities of this set (not recursively) into a new array.
	// Note: this does not call forceLoad(), so lazily-loaded entities may be missing.
	Array<Named *> *NameSet::content() {
		Array<Named *> *r = new (this) Array<Named *>();

		// The overload map is allocated lazily (lateInit/add/merge). If it does not exist yet,
		// the set is empty.
		if (!overloads)
			return r;

		for (Overloads::Iter it = overloads->begin(), last = overloads->end(); it != last; ++it) {
			NameOverloads *o = it.v();
			for (Nat i = 0; i < o->count(); i++)
				r->push(o->at(i));
		}
		return r;
	}

	// Convenience wrapper: recursively collect all overload groups containing templates into a
	// fresh array.
	Array<NameOverloads *> *NameSet::templateOverloads() {
		Array<NameOverloads *> *found = new (this) Array<NameOverloads *>();
		templateOverloads(found);
		return found;
	}

	// Append every overload group in this set that has templates to 'result'. Then recurse into
	// every contained entity that is itself a NameSet.
	void NameSet::templateOverloads(Array<NameOverloads *> *result) {
		if (!overloads)
			return;

		for (Overloads::Iter it = overloads->begin(), last = overloads->end(); it != last; ++it) {
			NameOverloads *group = it.v();
			if (group->anyTemplates())
				result->push(group);

			for (Nat k = 0; k < group->count(); k++) {
				if (NameSet *nested = as<NameSet>(group->at(k)))
					nested->templateOverloads(result);
			}
		}
	}

	// Return the overload group for 'name', or null if there is no such group.
	MAYBE(NameOverloads *) NameSet::allOverloads(Str *name) {
		// The overload map is allocated lazily. Without it there are no groups at all.
		if (!overloads)
			return null;

		Overloads::Iter found = overloads->find(name);
		if (found == overloads->end())
			return null;
		else
			return found.v();
	}

	void NameSet::forceLoad() {
		if (loaded)
			return;

		if (loading) {
			// This happens quite a lot...
			// WARNING(L"Recursive loading attempted for " << name);
			return;
		}

		loading = true;
		try {
			if (loadAll())
				loaded = true;
		} catch (...) {
			loading = false;
			throw;
		}
		loading = false;
	}

	// Look up 'part'. On a miss in a set that is not fully loaded, first try to load just that
	// name ('loadName'), fall back to loading everything ('forceLoad'), then look again.
	Named *NameSet::find(SimplePart *part, Scope source) {
		Named *result = tryFind(part, source);
		if (result || loaded)
			return result;

		if (!loadName(part))
			forceLoad();

		return tryFind(part, source);
	}

	// Look up 'part' among what is already loaded, without triggering any loading.
	Named *NameSet::tryFind(SimplePart *part, Scope source) {
		if (overloads) {
			Overloads::Iter group = overloads->find(part->name);
			if (group != overloads->end())
				return tryFind(part, source, group.v());
		}
		return null;
	}

	// Try 'part' against 'from'. If nothing is found, try each alternative given by
	// 'nextOption' in turn, and return the first hit.
	Named *NameSet::tryFind(SimplePart *part, Scope source, NameOverloads *from) {
		for (SimplePart *option = part; option; option = option->nextOption()) {
			if (Named *found = tryFindSingle(option, source, from))
				return found;
		}
		return null;
	}

	// Find the visible candidate in 'from' with the lowest non-negative badness according to
	// 'part'.
	// - If the best badness is above zero, the templates in 'from' get a chance to create a
	//   better match. A created entity is always added to this set.
	// - Returns null if nothing matches.
	// - Throws a LookupError listing all candidates if several share the lowest badness.
	Named *NameSet::tryFindSingle(SimplePart *part, Scope source, NameOverloads *from) {
		// Note: We do this in two steps. First, we find the best match and keep track of whether
		// there are multiple instances of it. If there are, we need to produce an error. To do
		// that, we loop through the candidates a second time (the error path is not critical).
		Named *bestCandidate = null;
		Bool multipleBest = false;
		Int best = std::numeric_limits<Int>::max();

		for (Nat i = 0; i < from->count(); i++) {
			Named *candidate = from->at(i);

			// Ignore ones that are not visible. Note: We delegate this to the part, so that it may
			// modify the default behavior.
			if (!part->visible(candidate, source))
				continue;

			Int badness = part->matches(candidate, source);
			if (badness < 0 || badness > best)
				continue;

			// Only a tie with an *existing* best candidate is ambiguous. Without the
			// 'bestCandidate' check, a first candidate with badness == INT_MAX would be flagged
			// as ambiguous on its own.
			if (bestCandidate && badness == best) {
				// Multiple best matches so far. We can't keep track of them without allocating memory.
				multipleBest = true;
			} else {
				multipleBest = false;
				best = badness;
				bestCandidate = candidate;
			}
		}

		// If we have a badness above zero, we might need to create a template to get a better
		// match.
		if (best > 0) {
			if (Named *created = from->createTemplate(this, part, source)) {
				// Note: We always add the created template, even if it is not visible or not the
				// best candidate.
				add(created);

				if (part->visible(created, source)) {
					// We more or less expect the created template to be a perfect match. A higher
					// badness probably means the template implementation is wrong.
					Int badness = part->matches(created, source);
					if (badness >= 0 && badness < best) {
						bestCandidate = created;
						best = badness;
						multipleBest = false;
					}
				}
			}
		}

		if (!multipleBest) {
			// We have an answer (possibly null).
			return bestCandidate;
		}

		// Error case: list all candidates with the best badness. Report the name being looked up
		// ('part') rather than 'this'. NameSet::toS prints the entire content of the set.
		StrBuf *msg = new (this) StrBuf();
		*msg << S("Multiple possible matches for ") << part << S(", all with badness ") << best << S("\n");

		for (Nat i = 0; i < from->count(); i++) {
			Named *candidate = from->at(i);

			if (!part->visible(candidate, source))
				continue;

			if (part->matches(candidate, source) == best)
				*msg << S("  Could be: ") << candidate->identifier() << S("\n");
		}

		throw new (this) LookupError(msg->toS());
	}

	// Default for sets without lazy loading: individual names cannot be loaded. Returning false
	// makes 'find' fall back to 'forceLoad'/'loadAll'.
	Bool NameSet::loadName(SimplePart *part) {
		return false;
	}

	// Default for sets without lazy loading: there is nothing to load, so report success.
	Bool NameSet::loadAll() {
		return true;
	}

	// Output every overload group in this set. An unallocated overload map (the map is created
	// lazily) produces no output instead of a null dereference.
	void NameSet::toS(StrBuf *to) const {
		if (!overloads)
			return;

		for (Overloads::Iter i = overloads->begin(), last = overloads->end(); i != last; ++i) {
			*to << i.v();
		}
	}

	// Merge all overload groups of 'from' into this set, group by group (see
	// NameOverloads::merge). If this set discards source, 'from' is told to do so first.
	void NameSet::merge(NameSet *from) {
		if (!overloads)
			overloads = new (this) Map<Str *, NameOverloads *>();

		if (sourceDiscarded)
			from->discardSource();

		// 'from->overloads' is only allocated by lateInit or on first add/merge, so it may
		// still be null. In that case there is nothing to merge.
		if (!from->overloads)
			return;

		for (Overloads::Iter i = from->overloads->begin(), end = from->overloads->end(); i != end; ++i) {
			overloads->at(i.k())->merge(i.v());
		}
	}

	// Report the differences between this set and 'with' to 'callback'.
	// - Groups present on both sides are compared item by item (NameOverloads::diff).
	// - Groups only present here are reported as removed; groups only in 'with' as added.
	// - Templates cannot be compared. Every template here is reported as removed and every
	//   template in 'with' as added.
	void NameSet::diff(NameSet *with, NameDiff &callback, ReplaceContext *ctx) {
		// Neither side has any content.
		if (!overloads && !with->overloads)
			return;

		// Only 'with' has content: everything was added.
		if (!overloads) {
			for (Overloads::Iter at = with->overloads->begin(), last = with->overloads->end(); at != last; ++at) {
				NameOverloads *added = at.v();
				added->diffAdded(callback);
				added->diffTemplatesAdded(callback);
			}
			return;
		}

		// Only this set has content: everything was removed.
		if (!with->overloads) {
			for (Overloads::Iter at = overloads->begin(), last = overloads->end(); at != last; ++at) {
				NameOverloads *removed = at.v();
				removed->diffRemoved(callback);
				removed->diffTemplatesRemoved(callback);
			}
			return;
		}

		// Both sides have content. A default-constructed iterator serves as 'end' for any map.
		Overloads::Iter last;

		for (Overloads::Iter at = overloads->begin(); at != last; ++at) {
			NameOverloads *here = at.v();
			NameOverloads *there = with->overloads->get(at.k(), null);
			if (there)
				here->diff(there, callback, ctx);
			else
				here->diffRemoved(callback);
			here->diffTemplatesRemoved(callback);
		}

		for (Overloads::Iter at = with->overloads->begin(); at != last; ++at) {
			NameOverloads *there = at.v();
			if (!overloads->has(at.k()))
				there->diffAdded(callback);
			there->diffTemplatesAdded(callback);
		}
	}

	// Default construction gives the end iterator.
	NameSet::Iter::Iter()
		: name(), pos(0), nextSet(null) {}

	// Start at the first entity of 'c'. When 'c' is exhausted, continue with 'next' (if any).
	NameSet::Iter::Iter(Map<Str *, NameOverloads *> *c, NameSet *next)
		: name(c->begin()), pos(0), nextSet(next) {
		advance();
	}

	// End iterators compare equal only to other end iterators. Otherwise both the map position
	// and the index within the overload group must agree.
	Bool NameSet::Iter::operator ==(const Iter &o) const {
		if (name == MapIter())
			return o.name == MapIter();

		if (name != o.name)
			return false;
		return pos == o.pos;
	}

	Bool NameSet::Iter::operator !=(const Iter &o) const {
		Bool same = (*this == o);
		return !same;
	}

	// The current entity.
	Named *NameSet::Iter::v() const {
		NameOverloads *group = name.v();
		return group->at(pos);
	}

	// Move forward past exhausted (or empty) overload groups. When the current map runs out,
	// jump to the start of 'nextSet', if one was given.
	void NameSet::Iter::advance() {
		while (name != MapIter()) {
			if (pos < name.v()->count())
				break;
			++name;
			pos = 0;
		}

		if (nextSet && name == MapIter())
			*this = nextSet->begin();
	}

	NameSet::Iter &NameSet::Iter::operator ++() {
		++pos;
		advance();
		return *this;
	}

	NameSet::Iter NameSet::Iter::operator ++(int) {
		Iter previous(*this);
		operator ++();
		return previous;
	}

	// Iterator over all (non-template) entities in this set.
	NameSet::Iter NameSet::begin() const {
		if (overloads)
			return Iter(overloads, null);
		else
			return Iter();
	}

	// Iterator over all entities in this set, followed by those in 'after'. 'after' may be null:
	// the branch with a map already accepts that, so the branch without a map must too, instead
	// of dereferencing it.
	NameSet::Iter NameSet::begin(NameSet *after) const {
		if (overloads)
			return Iter(overloads, after);
		else if (after)
			return after->begin();
		else
			return Iter();
	}

	// The end iterator (a default-constructed Iter).
	NameSet::Iter NameSet::end() const {
		return Iter();
	}

	// Return all (non-template) entities named 'name' in a new array. The array is empty if there
	// are none, including when the overload map has not been allocated yet.
	Array<Named *> *NameSet::findName(Str *name) const {
		Array<Named *> *result = new (this) Array<Named *>();

		// The overload map is created lazily; without it there are no matches.
		if (!overloads)
			return result;

		if (NameOverloads *f = overloads->get(name, null)) {
			result->reserve(f->count());
			for (Nat i = 0; i < f->count(); i++)
				result->push(f->at(i));
		}

		return result;
	}

	void NameSet::dbg_dump() const {
		PLN(L"Name set:");
		for (Overloads::Iter i = overloads->begin(); i != overloads->end(); ++i) {
			PLN(L" " << i.k() << L":");
			NameOverloads *o = i.v();
			for (Nat i = 0; i < o->count(); i++) {
				PLN(L"  " << o->at(i)->identifier());
			}
		}
	}

}
