# Knoten für Entscheidungsbaum class Knoten: def __init__(self): self.ist_label = None # für Blätter: True (bool) self.hat_label = None # für Blätter: Wert (bool) self.kriterium = None # speichert Kriterium (str) self.mit_kriterium = None # Knoten, falls Kriterium erfüllt self.ohne_kriterium = None # Knoten, falls Kriterium nicht erfüllt def trainiere(self, liste_von_trainingselementen): if self.alle_gleiches_label(liste_von_trainingselementen): # falls alle gleiches Label -> Blatt mit Label self.ist_label = True self.hat_label = liste_von_trainingselementen[0].hat_label else: # falls nicht: suche bestes Merkmal, teile Liste, trainiere rekursiv bestes_merkmal = self.suche_bestes_merkmal(liste_von_trainingselementen) if not bestes_merkmal == None: self.ist_label = False self.kriterium = bestes_merkmal self.mit_kriterium = Knoten() self.ohne_kriterium = Knoten() liste_mit_krit = [] liste_ohne_krit = [] for element in liste_von_trainingselementen: if element.eigenschaften.pop(bestes_merkmal): # Merkmal wird entfernt liste_mit_krit.append(element) else: liste_ohne_krit.append(element) self.mit_kriterium.trainiere(liste_mit_krit) self.ohne_kriterium.trainiere(liste_ohne_krit) else: # falls kein bestes Merkmal: Blatt mit Mehrheitslabel hat_label = 0 ohne_label = 0 for element in liste_von_trainingselementen: if element.hat_label: hat_label += 1 else: ohne_label += 1 self.ist_label = True self.hat_label = hat_label > ohne_label def alle_gleiches_label(self, liste_von_elementen): label = {e.hat_label for e in liste_von_elementen} return len(label) == 1 def suche_bestes_merkmal(self, liste_von_elementen): # Liste aller Eigenschaften anlegen alle_eigenschaften = liste_von_elementen[0].eigenschaften.keys() # kombinierte Gini-Impurity für jede Eigenschaft bestimmen kombinierte_gini_impurity = 1 bestes_merkmal = None for eigenschaft in alle_eigenschaften: hat_e_hat_l = 0 hat_e_ohne_l = 0 ohne_e_hat_l = 0 ohne_e_ohne_l = 0 for ele in liste_von_elementen: if ele.eigenschaften[eigenschaft]: if ele.hat_label: hat_e_hat_l += 1 else: hat_e_ohne_l += 1 else: if ele.hat_label: ohne_e_hat_l += 1 else: ohne_e_ohne_l += 1 gini_mit_eigenschaft = self.gini_impurity(hat_e_hat_l, hat_e_ohne_l, hat_e_hat_l + hat_e_ohne_l) gini_ohne_eigenschaft = self.gini_impurity(ohne_e_hat_l, ohne_e_ohne_l, ohne_e_hat_l + ohne_e_ohne_l) # geringste wählen if (gini_mit_eigenschaft + gini_ohne_eigenschaft) < kombinierte_gini_impurity: kombinierte_gini_impurity = gini_mit_eigenschaft + gini_ohne_eigenschaft bestes_merkmal = eigenschaft return bestes_merkmal def gini_impurity(self, mit, ohne, gesamt): return 1 - (mit/gesamt)**2 - (ohne/gesamt)**2 def bestimme_label(self, element): if self.ist_label: return self.hat_label else: if element.eigenschaften[self.kriterium]: return self.mit_kriterium.bestimme_label(element) else: return self.ohne_kriterium.bestimme_label(element) class Element: def __init__(self, eigenschaften, hat_label, id=None): self.eigenschaften = eigenschaften # dict self.hat_label = hat_label # bool self.id = id def liesAffendatei(filename): affen = [] with open(filename, mode='r', encoding='utf-8-sig') as file: eigenschaften = file.readline().strip().split(',') # Liest erste Zeile for zeile in iter(file.readline, ''): eigenschaften_aus_zeile = zeile.split(',') n = eigenschaften_aus_zeile[0] ele_eigenschaften = {e[0]:e[1].strip() == 'true' for e in zip(eigenschaften, eigenschaften_aus_zeile)} ele_eigenschaften.pop('Number') label = ele_eigenschaften.pop('beisst') affen.append(Element(ele_eigenschaften, label, id=n)) return affen def zeige_baum(wurzel): knotenliste = [wurzel] while knotenliste: knoten = knotenliste.pop(0) if knoten.ist_label: print(knoten.hat_label) else: print(knoten.kriterium) knotenliste.append(knoten.mit_kriterium) knotenliste.append(knoten.ohne_kriterium) if __name__ == '__main__': # Lies Affen aus Datei trainigs…, 'beisst' als Label extra behandeln affen = liesAffendatei('trainingtest.csv') # Baum trainieren k = Knoten() k.trainiere(affen) zeige_baum(k) # mit Affen aus Datei test… testen testaffen = liesAffendatei('testtest.csv') for affe in testaffen: if k.bestimme_label(affe) == affe.hat_label: print(f'Affe {affe.id} wurde richtig eingeordnet!') else: print(f'Affe {affe.id} wurde falsch eingeordnet!')