diff --git a/Entscheidungsbäume.py b/Entscheidungsbäume.py new file mode 100644 index 0000000..4bb63ed --- /dev/null +++ b/Entscheidungsbäume.py @@ -0,0 +1,140 @@ +# Knoten für Entscheidungsbaum +class Knoten: + + def __init__(self): + self.ist_label = None # für Blätter: True (bool) + self.hat_label = None # für Blätter: Wert (bool) + self.kriterium = None # speichert Kriterium (str) + self.mit_kriterium = None # Knoten, falls Kriterium erfüllt + self.ohne_kriterium = None # Knoten, falls Kriterium nicht erfüllt + + def trainiere(self, liste_von_trainingselementen): + if self.alle_gleiches_label(liste_von_trainingselementen): + # falls alle gleiches Label -> Blatt mit Label + self.ist_label = True + self.hat_label = liste_von_trainingselementen[0].hat_label + else: + # falls nicht: suche bestes Merkmal, teile Liste, trainiere rekursiv + bestes_merkmal = self.suche_bestes_merkmal(liste_von_trainingselementen) + if not bestes_merkmal == None: + self.ist_label = False + self.kriterium = bestes_merkmal + self.mit_kriterium = Knoten() + self.ohne_kriterium = Knoten() + liste_mit_krit = [] + liste_ohne_krit = [] + for element in liste_von_trainingselementen: + if element.eigenschaften.pop(bestes_merkmal): # Merkmal wird entfernt + liste_mit_krit.append(element) + else: + liste_ohne_krit.append(element) + self.mit_kriterium.trainiere(liste_mit_krit) + self.ohne_kriterium.trainiere(liste_ohne_krit) + else: + # falls kein bestes Merkmal: Blatt mit Mehrheitslabel + hat_label = 0 + ohne_label = 0 + for element in liste_von_trainingselementen: + if element.hat_label: + hat_label += 1 + else: + ohne_label += 1 + self.ist_label = True + self.hat_label = hat_label > ohne_label + + + def alle_gleiches_label(self, liste_von_elementen): + label = {e.hat_label for e in liste_von_elementen} + return len(label) == 1 + + def suche_bestes_merkmal(self, liste_von_elementen): + # Liste aller Eigenschaften anlegen + alle_eigenschaften = liste_von_elementen[0].eigenschaften.keys() + # kombinierte Gini-Impurity für jede Eigenschaft bestimmen + kombinierte_gini_impurity = 1 + bestes_merkmal = None + for eigenschaft in alle_eigenschaften: + hat_e_hat_l = 0 + hat_e_ohne_l = 0 + ohne_e_hat_l = 0 + ohne_e_ohne_l = 0 + for ele in liste_von_elementen: + if ele.eigenschaften[eigenschaft]: + if ele.hat_label: + hat_e_hat_l += 1 + else: + hat_e_ohne_l += 1 + else: + if ele.hat_label: + ohne_e_hat_l += 1 + else: + ohne_e_ohne_l += 1 + gini_mit_eigenschaft = self.gini_impurity(hat_e_hat_l, hat_e_ohne_l, hat_e_hat_l + hat_e_ohne_l) + gini_ohne_eigenschaft = self.gini_impurity(ohne_e_hat_l, ohne_e_ohne_l, ohne_e_hat_l + ohne_e_ohne_l) + # geringste wählen + if (gini_mit_eigenschaft + gini_ohne_eigenschaft) < kombinierte_gini_impurity: + kombinierte_gini_impurity = gini_mit_eigenschaft + gini_ohne_eigenschaft + bestes_merkmal = eigenschaft + return bestes_merkmal + + def gini_impurity(self, mit, ohne, gesamt): + return 1 - (mit/gesamt)**2 - (ohne/gesamt)**2 + + def bestimme_label(self, element): + if self.ist_label: + return self.hat_label + else: + if element.eigenschaften[self.kriterium]: + return self.mit_kriterium.bestimme_label(element) + else: + return self.ohne_kriterium.bestimme_label(element) + +class Element: + + def __init__(self, eigenschaften, hat_label, id=None): + self.eigenschaften = eigenschaften # dict + self.hat_label = hat_label # bool + self.id = id + +def liesAffendatei(filename): + affen = [] + with open(filename, mode='r', encoding='utf-8-sig') as file: + eigenschaften = file.readline().strip().split(',') # Liest erste Zeile + for zeile in iter(file.readline, ''): + eigenschaften_aus_zeile = zeile.split(',') + n = eigenschaften_aus_zeile[0] + ele_eigenschaften = {e[0]:e[1].strip() == 'true' for e in zip(eigenschaften, eigenschaften_aus_zeile)} + ele_eigenschaften.pop('Number') + label = ele_eigenschaften.pop('beisst') + affen.append(Element(ele_eigenschaften, label, id=n)) + return affen + +def zeige_baum(wurzel): + knotenliste = [wurzel] + while knotenliste: + knoten = knotenliste.pop(0) + if knoten.ist_label: + print(knoten.hat_label) + else: + print(knoten.kriterium) + knotenliste.append(knoten.mit_kriterium) + knotenliste.append(knoten.ohne_kriterium) + + +if __name__ == '__main__': + # Lies Affen aus Datei trainigs…, 'beisst' als Label extra behandeln + affen = liesAffendatei('trainingtest.csv') + # Baum trainieren + k = Knoten() + k.trainiere(affen) + zeige_baum(k) + # mit Affen aus Datei test… testen + testaffen = liesAffendatei('testtest.csv') + for affe in testaffen: + if k.bestimme_label(affe) == affe.hat_label: + print(f'Affe {affe.id} wurde richtig eingeordnet!') + else: + print(f'Affe {affe.id} wurde falsch eingeordnet!') + + + \ No newline at end of file diff --git a/testtest.csv b/testtest.csv new file mode 100644 index 0000000..de366bd --- /dev/null +++ b/testtest.csv @@ -0,0 +1,5 @@ +Number,auge auf,zähne,kreuzaugen,beisst +6,false,true,true,true +7,true,false,false,false +8,false,false,true,true +9,false,true,false,true \ No newline at end of file diff --git a/trainingtest.csv b/trainingtest.csv new file mode 100644 index 0000000..46705f7 --- /dev/null +++ b/trainingtest.csv @@ -0,0 +1,6 @@ +Number,auge auf,zähne,kreuzaugen,beisst +1,true,true,false,true +2,false,false,false,false +3,false,true,true,true +4,true,false,false,false +5,false,false,true,true \ No newline at end of file