Files
simpler_entscheidungsbaum/Entscheidungsbäume.py

140 lines
5.7 KiB
Python
Raw Normal View History

2024-06-24 14:21:34 +02:00
# Knoten für Entscheidungsbaum
class Knoten:
def __init__(self):
self.ist_label = None # für Blätter: True (bool)
self.hat_label = None # für Blätter: Wert (bool)
self.kriterium = None # speichert Kriterium (str)
self.mit_kriterium = None # Knoten, falls Kriterium erfüllt
self.ohne_kriterium = None # Knoten, falls Kriterium nicht erfüllt
def trainiere(self, liste_von_trainingselementen):
if self.alle_gleiches_label(liste_von_trainingselementen):
# falls alle gleiches Label -> Blatt mit Label
self.ist_label = True
self.hat_label = liste_von_trainingselementen[0].hat_label
else:
# falls nicht: suche bestes Merkmal, teile Liste, trainiere rekursiv
bestes_merkmal = self.suche_bestes_merkmal(liste_von_trainingselementen)
if not bestes_merkmal == None:
self.ist_label = False
self.kriterium = bestes_merkmal
self.mit_kriterium = Knoten()
self.ohne_kriterium = Knoten()
liste_mit_krit = []
liste_ohne_krit = []
for element in liste_von_trainingselementen:
if element.eigenschaften.pop(bestes_merkmal): # Merkmal wird entfernt
liste_mit_krit.append(element)
else:
liste_ohne_krit.append(element)
self.mit_kriterium.trainiere(liste_mit_krit)
self.ohne_kriterium.trainiere(liste_ohne_krit)
else:
# falls kein bestes Merkmal: Blatt mit Mehrheitslabel
hat_label = 0
ohne_label = 0
for element in liste_von_trainingselementen:
if element.hat_label:
hat_label += 1
else:
ohne_label += 1
self.ist_label = True
self.hat_label = hat_label > ohne_label
def alle_gleiches_label(self, liste_von_elementen):
label = {e.hat_label for e in liste_von_elementen}
return len(label) == 1
def suche_bestes_merkmal(self, liste_von_elementen):
# Liste aller Eigenschaften anlegen
alle_eigenschaften = liste_von_elementen[0].eigenschaften.keys()
# kombinierte Gini-Impurity für jede Eigenschaft bestimmen
kombinierte_gini_impurity = 1
bestes_merkmal = None
for eigenschaft in alle_eigenschaften:
hat_e_hat_l = 0
hat_e_ohne_l = 0
ohne_e_hat_l = 0
ohne_e_ohne_l = 0
for ele in liste_von_elementen:
if ele.eigenschaften[eigenschaft]:
if ele.hat_label:
hat_e_hat_l += 1
else:
hat_e_ohne_l += 1
else:
if ele.hat_label:
ohne_e_hat_l += 1
else:
ohne_e_ohne_l += 1
gini_mit_eigenschaft = self.gini_impurity(hat_e_hat_l, hat_e_ohne_l, hat_e_hat_l + hat_e_ohne_l)
gini_ohne_eigenschaft = self.gini_impurity(ohne_e_hat_l, ohne_e_ohne_l, ohne_e_hat_l + ohne_e_ohne_l)
# geringste wählen
if (gini_mit_eigenschaft + gini_ohne_eigenschaft) < kombinierte_gini_impurity:
kombinierte_gini_impurity = gini_mit_eigenschaft + gini_ohne_eigenschaft
bestes_merkmal = eigenschaft
return bestes_merkmal
def gini_impurity(self, mit, ohne, gesamt):
return 1 - (mit/gesamt)**2 - (ohne/gesamt)**2
def bestimme_label(self, element):
if self.ist_label:
return self.hat_label
else:
if element.eigenschaften[self.kriterium]:
return self.mit_kriterium.bestimme_label(element)
else:
return self.ohne_kriterium.bestimme_label(element)
class Element:
def __init__(self, eigenschaften, hat_label, id=None):
self.eigenschaften = eigenschaften # dict
self.hat_label = hat_label # bool
self.id = id
def liesAffendatei(filename):
affen = []
with open(filename, mode='r', encoding='utf-8-sig') as file:
eigenschaften = file.readline().strip().split(',') # Liest erste Zeile
for zeile in iter(file.readline, ''):
eigenschaften_aus_zeile = zeile.split(',')
n = eigenschaften_aus_zeile[0]
ele_eigenschaften = {e[0]:e[1].strip() == 'true' for e in zip(eigenschaften, eigenschaften_aus_zeile)}
ele_eigenschaften.pop('Number')
label = ele_eigenschaften.pop('beisst')
affen.append(Element(ele_eigenschaften, label, id=n))
return affen
def zeige_baum(wurzel):
knotenliste = [wurzel]
while knotenliste:
knoten = knotenliste.pop(0)
if knoten.ist_label:
print(knoten.hat_label)
else:
print(knoten.kriterium)
knotenliste.append(knoten.mit_kriterium)
knotenliste.append(knoten.ohne_kriterium)
if __name__ == '__main__':
# Lies Affen aus Datei trainigs…, 'beisst' als Label extra behandeln
affen = liesAffendatei('trainingtest.csv')
# Baum trainieren
k = Knoten()
k.trainiere(affen)
zeige_baum(k)
# mit Affen aus Datei test… testen
testaffen = liesAffendatei('testtest.csv')
for affe in testaffen:
if k.bestimme_label(affe) == affe.hat_label:
print(f'Affe {affe.id} wurde richtig eingeordnet!')
else:
print(f'Affe {affe.id} wurde falsch eingeordnet!')