kurse:efcomputergrafik:knn

This is an old revision of the document!


Ziel

Den $k$-nearest-neighbour Algorithmus in $\mathbb{R}^2$ implementieren.

Wichtige Zutaten
  • Liste mit Distanzen und Klassen, i.e. [[1,0.033131],[0,0.123131],[1,0.123124141],[0,1.2123141]]
  • Sortieren dieser Liste um die $k$ nächsten Nachbarn resp. deren Klasse zu bestimmen:
    • Sortieren von Listen kann mit Python mit sorted gelöst werden. Speziell für unseren Fall ist “Example 3” spannend.
    • Auf Grund der sortierted Liste kann die Mehrheitsmeinung der $k$ nächsten Nachbarn bestimmt werden

Empfehlung: Mindestens zwei Funktionen definieren. Eine zur Berechnugn der Distanz-Klassen-Liste, eine zur Zuweisung der Klasse (0 oder 1).

Wer $k$-nearest-neighbour implementiert hat, kann sich überlegen, wie die untenstehende Daten klassifiziert werden sollen: Die Daten finden sich in einer neuen ZIP-Datein

Click to display ⇲

Click to hide ⇱

knn.py
from gpanel import *
import time
import csv  # um Text-Dateien im CSV-Format zu lesen
import random
 
# CSV-File oefffnen
 
csvfile = open('C:/temp/data.csv', 'r')
 
# CSV-File einlesen.
 
reader = csv.reader(csvfile, delimiter=',',
                    quoting=csv.QUOTE_NONNUMERIC)
 
# CSV-File in Liste umwandeln
 
datalist = list(reader)
 
# GPanel aufsetzen
 
makeGPanel(-20, 20, -20, 20)
drawGrid(-15, 15, -15, 15)
 
# Punkte zeichnen
 
for i in range(len(datalist)):
    move(datalist[i][0], datalist[i][1])
    if int(datalist[i][2]) == 1:
        setColor('orange')
    else:
        setColor('green')
    fillCircle(0.1)
 
 
# Funktion, die einem Punkt eine Klasse auf Grund der k naechsten Nachbarn zuweist.
 
def assignClass(point, k):
 
    # Funktion die die Distanzen vom Punkt zu den exisitierenden Punkte berechnet
 
    # Liste um die Distanzen zu speichern. Achtung: Speicher!!
 
    distlist = []
 
    #
 
    for i in range(len(datalist)):
        distlist.append([datalist[i][2], sqrt((point[0]
                        - datalist[i][0]) ** 2 + (point[1]
                        - datalist[i][1]) ** 2)])
 
    # das waere ein sehr Pythonesquer Weg mit Lambda-Funktionen
    # nearest = sorted(distlist,key=lambda result:result[1])
 
    # definiere eine Funktion, welche das zweite Element zurueckgibt. ....
 
    def sortFunction(item):
        return item[1]
 
    # Sortiere die liste! Achtung: Man koennte auch ohne key Arbeiten, wenn Distanz an 1. Stelle waere
 
    nearest = sorted(distlist, key=sortFunction)
 
    # Zaehle Klassennummern und entscheide ueber Klasse. Achtung: Laesst sich so nicht auf k>2 Klassen erweitern.
 
    classsum = sum([nearest[i][0] for i in range(k)])
    if classsum > k / 2:
        retclass = 1
    elif classsum < k / 2:
        retclass = 0
    else:
        retclass = random.randint(0, 1)
    return retclass
 
 
# Funktion um Pt zu zeichnen und mit Label auf Grund der k-naechsten Nachbarn zu versehen
 
def drawAndLablePoint(point, k):
    guessedclass = assignClass(point, k)
    move(point)
    setColor('black')
    fillRectangle(0.02, 0.02)
    text(str(guessedclass))
 
 
# Programm teseten
 
drawAndLablePoint([-0.5, 0.5], 3)
drawAndLablePoint([0.5, 0.5], 3)
print assignClass([-1, -1], 3)
print assignClass([-1, 1], 3)
print assignClass([1, -1], 3)
print assignClass([1, 1], 3)
print assignClass([-1, 0.5], 3)

Ziele

  • Klassifizierungsfehler auf Testdaten für verschiedene $k$ berechnen und Tabelle erstellen.
  • ZIP-Code Problematik verstehen: ZIP-Code → Ziffer → 16×16 Bild → Liste mit 256 Graustufen-Werten → kNN in $\mathbb{R}^{256}$.
  • Einzelne Ziffern als Grafikdatei einlesen und als 256 Zahlwerte pro Bild als Liste speichern:
    • Eine Funktion schreiben, die als Argument einen Dateinamen hat und als Rückgabewert eine Liste mit 256 Elementen.
    • Diese Funktion auf alle Dateien anwenden (siehe unten) und die Ziffer aus dem Dateinamen in eine Liste von Liste mit 256+1 Elementen speichern

Hinweise

  • Klassifizierungsfehler: $Y$ sind die wahren Klassen des Testsets, $\hat Y$ die vorhergesagten Klassen. Der Klassifizierungsfehler die relative Anzahl der falsch klassifizierten $\hat Y$. Wenn $Z=\begin{cases} 1&\text{wenn }Y\neq\hat Y\\0&\text{sonst.}\end{cases}$ dann ist der Klassifizierungsfehler $$\frac{1}{n}\sum_{i=1}^n Z.$$ NB: Der Klassifizierungsfehler ist nichts anderes als die falsch klassifizierten in Prozent!.
  • Konvertierung der Bild-Dateien zu Zahlwerten
    • Bilder können in Tigerjython mit getImage eingelesen werden.
    • Verzeichnisse können mit os.listdir() durchlaufen werden:
      listdir.py
      import os
      for filename in os.listdir("C:/temp/"):
          print(filename)
    • Mit filename.split('_', 3) kann der String “filename” aufgeteilt ("gesplitted") werden, die 3 steht dabei für das Dritte Element nach dem Split in der Liste und entspricht der Ziffer.
    • Die Graustufenwerte von 0 bis 255 sollten auf Werte zwischen -1 und 1 “umgelegt” werden.
    • Ziel ist eine Liste mit 256 + 1 Einträgen pro Bilddatei. Diese Liste könnte dann wieder als CSV Datei gespeichert werden.
    • Speicherung als CSV passiert am einfachsten über CSV schreiben:
      writecsv.py
       outcsv = open("C:/temp/outfile.csv", 'a'); 
       
      # CSV-writer konfigurieren.
      writer = csv.writer(outcsv, delimiter=',', lineterminator='\n')
       
      for item in datalist:
          #Jeden Eintrag der Datalist als Zeile ausgeben
          writer.writerow([item[0], item[1], item[2]])
       
      # Wrtier schliessen
      outcsv.close()
       

Lösungen

Liste von Bilddateien

Liste von Bilddateien

pixelist_from_directory.py
import gpanel  #um bilder einzulesen
import os  #um Verzeichnisse zu listen
import csv #um CSV-Dateien zu lesen.
 
# Pfad zu den Bilddateien
digitsdirectory = 'C:/temp/digits/train/'
 
 
def getPixeListFromFilePath(filepath):
    img = gpanel.getImage(filepath)
    w = img.getWidth()
    h = img.getHeight()
    pixellist = []
    for y in range(h):
        for x in range(w):
 
            # color is ein Objekt mit verschiedenen Attributen, u.a. red, green, blue.
            # bei grau sind rot=gruen=blau, d.h., eine Farbe auslesen reicht.
            # siehe auch https://docs.oracle.com/javase/7/docs/api/java/awt/Color.html
            color = img.getPixelColor(x, y)
 
            # umlegen auf das Intervall [-1,1] zwecks Normalisierung
            value = color.red / 255 * 2 - 1
            # an liste anhaengen
            pixellist.append(value)
 
    return pixellist
 
# Lese Ziffer aus Dateiname aus.
def getDigitFromFileName(filename):
    return int(filename.split('_', 3)[2])
 
 
# leere Liste fuer alle Trainingsdaten der Form [-0.93,0.331,....,0.99,3]
 
trainingset = []
 
# durch alle files im Ziffernverzeichnis loopen
for filename in [filename for filename in os.listdir(digitsdirectory) if filename.endswith("gif")]:
 
    # Ziffer auslesen
    currdigit = getDigitFromFileName(filename)
 
    # Pixelliste von Datei auslesen
    currpixellist = getPixeListFromFilePath(digitsdirectory + filename)
 
    # Der Pixelliste die Ziffer anhaengen
    currpixellist.append(currdigit)
 
    # Gesamte Liste dem trainingsset anhaengen.
    trainingset.append(currpixellist)
 
# Das Trainingsset kann jetzt verwendet werden
# print(trainingsset)
  • kurse/efcomputergrafik/knn.1585048463.txt.gz
  • Last modified: 2020/03/24 12:14
  • by Simon Knaus