from heapq import heappushpop
from podatki import stevke_ucna_mno, opisi_ucna_mno, stevke_testna_mno, opisi_testna_mno


def dist(a, b, k):
    """
    Izracunamo normo ||a - b||_k med vektorjema/vrsticama a in b.
    """
    n = len(a)
    return sum(abs(a[i] - b[i]) ** k for i in range(n)) ** (1/k)


def kNN(stevke_ucna, slike_ucna, stevke_testna, slike_testna, k):
    """
    Za vsako sliko iz testne mnozice najdemo najblizjih k sosedov iz ucne mnozice.
    Napovemo, da slika predstavlja stevko, ki jo napove najvec sosedov.

    Za merjenje razdalje uporabljamo Frobeniusovo normo.

    Na koncu izpisemo uspesnot metode.
    """

    # statistika[i][j] = stevilo primerov, ko smo stevko i proglasili za j
    statistika_napovedi = [[0 for _ in range(10)] for _ in range(10)]
    for resnicna_stevka, opis1 in zip(stevke_testna, slike_testna):
        # racunanje sosedov
        najblizji_sosedje = [(float("-inf"), None) for _ in range(k)]  # najblizjih k sosedov: (-razdalja, napoved)
        for stevka, opis2 in zip(stevke_ucna, slike_ucna):
            nova_dist = dist(opis1, opis2, 2)
            if nova_dist < -najblizji_sosedje[0][0]:
                heappushpop(najblizji_sosedje, (-nova_dist, stevka))
        # glasovanje
        glasovi = [0] * 10
        for _, stevka in najblizji_sosedje:
            glasovi[stevka] += 1
        napovedana_stevka = max(list(range(10)), key=lambda stevka: glasovi[stevka])
        statistika_napovedi[resnicna_stevka][napovedana_stevka] += 1

    for stevka, stat in enumerate(statistika_napovedi):
        primerov = sum(stat)
        natancnost = stat[stevka] / primerov * 100
        print("Napovedi za stevko {}: {} --> {}/{} = {: >6.2f}%".format(stevka, stat, stat[stevka], primerov, natancnost))


k = 3
kNN(stevke_ucna_mno, opisi_ucna_mno, stevke_testna_mno, opisi_testna_mno, k)

