package LinkedHashMap
import HashMap
import Wurstunit

/**
Hash map that additionally remembers the order in which keys were first put.

Lookups go through a HashMap from key to entry. The entries themselves are
chained into a circular doubly linked list that starts and ends at `dummy`;
iteration walks that list from `dummy.next` onwards.
*/
public class LinkedHashMap<K, V>
    // Sentinel entry closing the circular list; it has no key, value or parent.
    protected constant dummy       = new LHMEntry<K, V>(null, null, null)
    // Key -> entry index used by has/get/put/remove.
    private   constant keysEntries = new HashMap<K, LHMEntry<K, V>>()
    // Last entry of the list, or dummy while the map is empty.
    protected var      tail        = dummy
    // Number of entries currently linked into the list.
    private   var      size        = 0


    // Closes the (still empty) list onto the sentinel.
    construct()
        dummy.next = dummy
        dummy.prev = dummy


    /** Appends the given entry at the end of the list and counts it. */
    private function addEntry(LHMEntry<K, V> e)
        // The list is circular through dummy, so linking after tail also
        // covers the empty map, where tail is dummy itself.
        e.prev = tail
        e.next = dummy
        tail.next = e
        dummy.prev = e
        tail = e

        size++


    /**
    Saves the given value under the given key.

    A key that is already present only gets its value replaced and keeps its
    position in the iteration order.
    */
    function put(K key, V value)
        if keysEntries.has(key)
            let existing = keysEntries.get(key)
            existing.value = value
        else
            let created = new LHMEntry<K, V>(this, key, value)
            addEntry(created)
            keysEntries.put(key, created)


    /**
    Gets the value saved under the given key.

    @returns the stored value, or null if there is no such key.
    */
    function get(K key) returns V
        // Without this guard a missing key would read `.value` of a null entry.
        if not keysEntries.has(key)
            return null

        return keysEntries.get(key).value


    /** Checks whether a value is saved under the given key. */
    function has(K key) returns bool
        return keysEntries.has(key)


    /** Get the number of mappings in this LinkedHashMap. */
    function size() returns int
        return size


    /**
    Removes the value saved under the given key.

    @returns true on success, and false if there was no such key.
    */
    function remove(K key) returns bool
        if not keysEntries.has(key)
            return false

        remove(keysEntries.get(key))
        return true


    /** Removes the given entry: drops its key from the index and destroys it. */
    protected function remove(LHMEntry<K, V> e)
        keysEntries.remove(e.key)
        // Destroying the entry unlinks it from the list (see LHMEntry.ondestroy).
        destroy e
        size--


    /** Clears the LinkedHashMap of all entries. */
    function flush()
        var e = dummy.next
        while e != dummy
            // Read the successor first: e must not be touched once destroyed.
            let following = e.next
            destroy e
            e = following

        keysEntries.flush()
        size = 0


    /** Gets an iterator for this LinkedHashMap. */
    function iterator() returns LHMIterator<K, V>
        return new LHMIterator(this, dummy)


    /**
    Removes every entry for which the predicate returns true.

    The predicate is destroyed afterwards.
    */
    function removeWhen(LinkedHashMapPredicate<K, V> predicate)
        var e = dummy.next
        while e != dummy
            let matches = predicate.isTrueFor(e)
            // Remember the successor before e is possibly destroyed.
            let following = e.next

            if matches
                remove(e)
            e = following

        destroy predicate


    /**
    Runs the callback once for every entry, in list order.

    The callback is destroyed afterwards.
    */
    function forEach(LinkedHashMapCallback<K, V> callback)
        let iter = iterator()
        for kv from iter
            callback.run(kv)
        iter.close()

        destroy callback


    // Destroys all entries, then the key index and the sentinel.
    ondestroy
        flush()
        destroy keysEntries
        destroy dummy


/**
One key/value mapping of a LinkedHashMap, which is also a node of the map's
circular doubly linked list.
*/
public class LHMEntry<K, V>

    // Map this entry belongs to; null for an entry created without a map.
    private constant LinkedHashMap<K, V> parent
    constant K key
    V value

    // Neighbours in the parent's list; set by the map when the entry is linked.
    protected thistype prev = null
    protected thistype next = null

    // Stores the owning map, key and value; the entry is not linked yet.
    construct(LinkedHashMap<K, V> parent, K key, V value)
        this.parent = parent
        this.key    = key
        this.value  = value

    // Unlinks the entry from its neighbours and keeps the parent's tail valid.
    ondestroy
        // An entry may have been created with a null parent, so only touch
        // the parent's tail when there is a parent.
        if parent != null and parent.tail == this
            parent.tail = prev

        prev.next = next
        next.prev = prev


/**
Iterator over the entries of a LinkedHashMap in list order.

It starts on the sentinel entry and follows the `next` links until the
following entry would be the sentinel again.
*/
public class LHMIterator<K, V>
    private LHMEntry<K, V>      dummy
    private LHMEntry<K, V>      current
    private LinkedHashMap<K, V> parentMap
    // True while `current` was handed out by next() and not removed since.
    private bool                removable = false

    // Positions the iterator on the sentinel, i.e. before the first entry.
    construct(LinkedHashMap<K, V> parentMap, LHMEntry<K, V> dummy)
        this.dummy     = dummy
        this.current   = dummy
        this.parentMap = parentMap


    /**
    Remove the current element from the LinkedHashMap.

    May be called at most once after each call to next().
    */
    function remove()
        if current == dummy
            error("Tried to delete dummy from LinkedHashMap iterator.")
        if not removable
            error("Tried to remove the same element twice from LinkedHashMap iterator.")

        // Step back before the entry is destroyed, so that hasNext()/next()
        // never read a destroyed entry and still continue with its successor.
        let removed = current
        current = removed.prev
        removable = false

        parentMap.remove(removed)


    /** Returns true while the entry after the current one is not the sentinel. */
    function hasNext() returns bool
        return current.next != dummy


    /** Advances to the following entry and returns it. */
    function next() returns LHMEntry<K, V>
        current = current.next
        removable = true
        return current


    /** Destroys this iterator. */
    function close()
        destroy this


/** Condition over a single entry, as accepted by LinkedHashMap.removeWhen. */
public interface LinkedHashMapPredicate<K, V>
    /** Returns true if the given entry satisfies the condition. */
    function isTrueFor(LHMEntry<K, V> kv) returns bool


/** Action on a single entry, as accepted by LinkedHashMap.forEach. */
public interface LinkedHashMapCallback<K, V>
    /** Called with the entry to process. */
    function run(LHMEntry<K, V> kv)


// Small object type with an int payload, used as key and value in the tests.
class TestClass
    int val

    // Stores the given number as the payload.
    construct(int val)
        this.val = val


    /** Renders the payload prefixed with a fixed class tag. */
    function toString() returns string
        let number = val.toString()
        return "[Test Class] " + number


// A value stored with put can be read back with get under the same key.
@test function checkPutFetch()
    let numbers = new LinkedHashMap<int, int>()
    numbers.put(5, 6)

    numbers.get(5).assertEquals(6)

    destroy numbers


// Iterating visits every entry: the payloads of all keys and values add up.
@test function checkCanIterate()
    let map = new LinkedHashMap<TestClass, TestClass>()
    let firstKey    = new TestClass(1)
    let firstValue  = new TestClass(2)
    let secondKey   = new TestClass(3)
    let secondValue = new TestClass(4)

    map.put(firstKey, firstValue)
    map.put(secondKey, secondValue)

    var total = 0

    // The for-in loop obtains the map's iterator and closes it when done.
    for entry in map
        total += entry.key.val
        total += entry.value.val

    total.assertEquals(10)
    destroy map


// Two maps alive at the same time keep their own entries.
@test function canHaveMultipleMaps()
    let first  = new LinkedHashMap<int, int>()
    let second = new LinkedHashMap<int, int>()

    first.put(1, 2)
    second.put(3, 4)

    let fromFirst  = first.get(1)
    let fromSecond = second.get(3)
    fromFirst.assertEquals(2)
    fromSecond.assertEquals(4)

    destroy first
    destroy second


// Iterating a map inside an iteration of the same map visits all keys in both loops.
@test function canNestedIterate()
    let map = new LinkedHashMap<int, int>()

    map..put(1, 2)..put(2, 3)..put(3, 4)..put(4, 5)..put(5, 6)

    var outerSum = 0
    for kv in map
        outerSum += kv.key

        var innerSum = 0
        for innerKv in map
            innerSum += innerKv.key

        innerSum.assertEquals(15)
    outerSum.assertEquals(15)

    // Release the map like the other tests do.
    destroy map


// A second iterator, created after the first was closed, sees the same entries again.
@test function canIterateSameMapTwice()
    let map = new LinkedHashMap<int, int>()

    map.put(1, 2)
    map.put(3, 4)

    var sum = 0

    let iter = map.iterator()
    while iter.hasNext()
        let entry = iter.next()

        sum += entry.key
        sum += entry.value
    iter.close()

    let iter2 = map.iterator()
    while iter2.hasNext()
        let entry = iter2.next()

        sum += entry.key
        sum += entry.value
    iter2.close()

    sum.assertEquals(20)

    // Release the map like the other tests do.
    destroy map


// Entries dropped by removeWhen no longer show up in a later iteration.
@test function canHandleDestroyedElements()
    let map = new LinkedHashMap<TestClass, TestClass>()

    let firstKey    = new TestClass(1)
    let firstValue  = new TestClass(2)
    let secondKey   = new TestClass(3)
    let secondValue = new TestClass(4)

    map.put(firstKey, firstValue)
    map.put(secondKey, secondValue)

    var before = 0
    for entry in map
        before += entry.key.val + entry.value.val

    before.assertEquals(10)

    // Mark the first key so that the predicate below selects its entry.
    firstKey.val = -1

    map.removeWhen((LHMEntry<TestClass, TestClass> kv) -> kv.key.val == -1 or kv.value.val == -1)

    var after = 0
    for entry in map
        after += entry.key.val + entry.value.val

    after.assertEquals(7)

    destroy map

// A map emptied through removeWhen, remove and iterator removal can be refilled and iterated.
@test function canBecomeEmptyThenIterateAgainLater()
    let map = new LinkedHashMap<int, int>()

    map..put(1, 1)..put(2, 2)..put(3, 3)..put(4, 4)

    map.removeWhen((LHMEntry<int, int> kv) -> kv.key < 3)

    map.size().assertEquals(2)

    map.remove(4)

    map.size().assertEquals(1)

    // Only key 3 is left; remove it through the iterator.
    let iter = map.iterator()
    for kv from iter
        kv.value.assertEquals(3)
        iter.remove()
    iter.close()

    map.size().assertEquals(0)

    map..put(1, 2)..put(2, 3)..put(3, 4)

    var vals = 0
    let braap = map.iterator()
    for kv from braap
        vals += kv.value
    braap.close()

    vals.assertEquals(9)

    // Release the map like the other tests do.
    destroy map


// Iterating a map without entries must not run the loop body.
@test function iterateEmptyMap()
    let m = new LinkedHashMap<int, int>()

    for _ in m
        testFail("Nothing should be iterated here.")

    // Test when there are two LHM's, each with its own dummy.
    let n = new LinkedHashMap<int, int>()

    for _ in n
        testFail("Nothing should be iterated here.")

    // Release both maps like the other tests do.
    destroy m
    destroy n