140 lines
5.7 KiB
Python
140 lines
5.7 KiB
Python
|
|
# Knoten für Entscheidungsbaum
|
||
|
|
class Knoten:
|
||
|
|
|
||
|
|
def __init__(self):
|
||
|
|
self.ist_label = None # für Blätter: True (bool)
|
||
|
|
self.hat_label = None # für Blätter: Wert (bool)
|
||
|
|
self.kriterium = None # speichert Kriterium (str)
|
||
|
|
self.mit_kriterium = None # Knoten, falls Kriterium erfüllt
|
||
|
|
self.ohne_kriterium = None # Knoten, falls Kriterium nicht erfüllt
|
||
|
|
|
||
|
|
def trainiere(self, liste_von_trainingselementen):
|
||
|
|
if self.alle_gleiches_label(liste_von_trainingselementen):
|
||
|
|
# falls alle gleiches Label -> Blatt mit Label
|
||
|
|
self.ist_label = True
|
||
|
|
self.hat_label = liste_von_trainingselementen[0].hat_label
|
||
|
|
else:
|
||
|
|
# falls nicht: suche bestes Merkmal, teile Liste, trainiere rekursiv
|
||
|
|
bestes_merkmal = self.suche_bestes_merkmal(liste_von_trainingselementen)
|
||
|
|
if not bestes_merkmal == None:
|
||
|
|
self.ist_label = False
|
||
|
|
self.kriterium = bestes_merkmal
|
||
|
|
self.mit_kriterium = Knoten()
|
||
|
|
self.ohne_kriterium = Knoten()
|
||
|
|
liste_mit_krit = []
|
||
|
|
liste_ohne_krit = []
|
||
|
|
for element in liste_von_trainingselementen:
|
||
|
|
if element.eigenschaften.pop(bestes_merkmal): # Merkmal wird entfernt
|
||
|
|
liste_mit_krit.append(element)
|
||
|
|
else:
|
||
|
|
liste_ohne_krit.append(element)
|
||
|
|
self.mit_kriterium.trainiere(liste_mit_krit)
|
||
|
|
self.ohne_kriterium.trainiere(liste_ohne_krit)
|
||
|
|
else:
|
||
|
|
# falls kein bestes Merkmal: Blatt mit Mehrheitslabel
|
||
|
|
hat_label = 0
|
||
|
|
ohne_label = 0
|
||
|
|
for element in liste_von_trainingselementen:
|
||
|
|
if element.hat_label:
|
||
|
|
hat_label += 1
|
||
|
|
else:
|
||
|
|
ohne_label += 1
|
||
|
|
self.ist_label = True
|
||
|
|
self.hat_label = hat_label > ohne_label
|
||
|
|
|
||
|
|
|
||
|
|
def alle_gleiches_label(self, liste_von_elementen):
|
||
|
|
label = {e.hat_label for e in liste_von_elementen}
|
||
|
|
return len(label) == 1
|
||
|
|
|
||
|
|
def suche_bestes_merkmal(self, liste_von_elementen):
|
||
|
|
# Liste aller Eigenschaften anlegen
|
||
|
|
alle_eigenschaften = liste_von_elementen[0].eigenschaften.keys()
|
||
|
|
# kombinierte Gini-Impurity für jede Eigenschaft bestimmen
|
||
|
|
kombinierte_gini_impurity = 1
|
||
|
|
bestes_merkmal = None
|
||
|
|
for eigenschaft in alle_eigenschaften:
|
||
|
|
hat_e_hat_l = 0
|
||
|
|
hat_e_ohne_l = 0
|
||
|
|
ohne_e_hat_l = 0
|
||
|
|
ohne_e_ohne_l = 0
|
||
|
|
for ele in liste_von_elementen:
|
||
|
|
if ele.eigenschaften[eigenschaft]:
|
||
|
|
if ele.hat_label:
|
||
|
|
hat_e_hat_l += 1
|
||
|
|
else:
|
||
|
|
hat_e_ohne_l += 1
|
||
|
|
else:
|
||
|
|
if ele.hat_label:
|
||
|
|
ohne_e_hat_l += 1
|
||
|
|
else:
|
||
|
|
ohne_e_ohne_l += 1
|
||
|
|
gini_mit_eigenschaft = self.gini_impurity(hat_e_hat_l, hat_e_ohne_l, hat_e_hat_l + hat_e_ohne_l)
|
||
|
|
gini_ohne_eigenschaft = self.gini_impurity(ohne_e_hat_l, ohne_e_ohne_l, ohne_e_hat_l + ohne_e_ohne_l)
|
||
|
|
# geringste wählen
|
||
|
|
if (gini_mit_eigenschaft + gini_ohne_eigenschaft) < kombinierte_gini_impurity:
|
||
|
|
kombinierte_gini_impurity = gini_mit_eigenschaft + gini_ohne_eigenschaft
|
||
|
|
bestes_merkmal = eigenschaft
|
||
|
|
return bestes_merkmal
|
||
|
|
|
||
|
|
def gini_impurity(self, mit, ohne, gesamt):
|
||
|
|
return 1 - (mit/gesamt)**2 - (ohne/gesamt)**2
|
||
|
|
|
||
|
|
def bestimme_label(self, element):
|
||
|
|
if self.ist_label:
|
||
|
|
return self.hat_label
|
||
|
|
else:
|
||
|
|
if element.eigenschaften[self.kriterium]:
|
||
|
|
return self.mit_kriterium.bestimme_label(element)
|
||
|
|
else:
|
||
|
|
return self.ohne_kriterium.bestimme_label(element)
|
||
|
|
|
||
|
|
class Element:
|
||
|
|
|
||
|
|
def __init__(self, eigenschaften, hat_label, id=None):
|
||
|
|
self.eigenschaften = eigenschaften # dict
|
||
|
|
self.hat_label = hat_label # bool
|
||
|
|
self.id = id
|
||
|
|
|
||
|
|
def liesAffendatei(filename):
|
||
|
|
affen = []
|
||
|
|
with open(filename, mode='r', encoding='utf-8-sig') as file:
|
||
|
|
eigenschaften = file.readline().strip().split(',') # Liest erste Zeile
|
||
|
|
for zeile in iter(file.readline, ''):
|
||
|
|
eigenschaften_aus_zeile = zeile.split(',')
|
||
|
|
n = eigenschaften_aus_zeile[0]
|
||
|
|
ele_eigenschaften = {e[0]:e[1].strip() == 'true' for e in zip(eigenschaften, eigenschaften_aus_zeile)}
|
||
|
|
ele_eigenschaften.pop('Number')
|
||
|
|
label = ele_eigenschaften.pop('beisst')
|
||
|
|
affen.append(Element(ele_eigenschaften, label, id=n))
|
||
|
|
return affen
|
||
|
|
|
||
|
|
def zeige_baum(wurzel):
|
||
|
|
knotenliste = [wurzel]
|
||
|
|
while knotenliste:
|
||
|
|
knoten = knotenliste.pop(0)
|
||
|
|
if knoten.ist_label:
|
||
|
|
print(knoten.hat_label)
|
||
|
|
else:
|
||
|
|
print(knoten.kriterium)
|
||
|
|
knotenliste.append(knoten.mit_kriterium)
|
||
|
|
knotenliste.append(knoten.ohne_kriterium)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
|
||
|
|
# Lies Affen aus Datei trainigs…, 'beisst' als Label extra behandeln
|
||
|
|
affen = liesAffendatei('trainingtest.csv')
|
||
|
|
# Baum trainieren
|
||
|
|
k = Knoten()
|
||
|
|
k.trainiere(affen)
|
||
|
|
zeige_baum(k)
|
||
|
|
# mit Affen aus Datei test… testen
|
||
|
|
testaffen = liesAffendatei('testtest.csv')
|
||
|
|
for affe in testaffen:
|
||
|
|
if k.bestimme_label(affe) == affe.hat_label:
|
||
|
|
print(f'Affe {affe.id} wurde richtig eingeordnet!')
|
||
|
|
else:
|
||
|
|
print(f'Affe {affe.id} wurde falsch eingeordnet!')
|
||
|
|
|
||
|
|
|
||
|
|
|