Дерево решений в R - Дерево классификации & Код на R с примером

Содержание:

Anonim

Что такое деревья решений?

Деревья решений - это универсальный алгоритм машинного обучения, который может выполнять задачи как классификации, так и регрессии. Это очень мощные алгоритмы, способные подбирать сложные наборы данных. Кроме того, деревья решений являются фундаментальными компонентами случайных лесов, которые являются одними из самых эффективных алгоритмов машинного обучения, доступных сегодня.

Обучение и визуализация деревьев решений

Чтобы построить ваше первое дерево решений в примере R, мы будем действовать следующим образом в этом руководстве по дереву решений:

  • Шаг 1. Импортируйте данные
  • Шаг 2. Очистите набор данных
  • Шаг 3. Создайте набор поездов / тестов
  • Шаг 4: Постройте модель
  • Шаг 5. Сделайте прогноз
  • Шаг 6. Измерьте эффективность
  • Шаг 7: Настройте гиперпараметры

Шаг 1) Импортируйте данные

Если вам интересно узнать о судьбе Титаника, вы можете посмотреть это видео на Youtube. Цель этого набора данных - предсказать, какие люди с большей вероятностью выживут после столкновения с айсбергом. Набор данных содержит 13 переменных и 1309 наблюдений. Набор данных упорядочен по переменной X.

set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)

Выход:

## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)

Выход:

## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S

Из вывода головы и хвоста вы можете заметить, что данные не перемешиваются. Это большая проблема! Когда вы разделите свои данные между поездом и тестовым набором, вы выберете только пассажира из класса 1 и 2 (ни один пассажир из класса 3 не входит в верхние 80 процентов наблюдений), что означает, что алгоритм никогда не увидит особенности пассажира 3 класса. Эта ошибка приведет к плохому прогнозу.

Чтобы решить эту проблему, вы можете использовать функцию sample ().

shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)

Дерево решений R-код Пояснение

  • sample (1: nrow (titanic)): генерирует случайный список с индексами от 1 до 1309 (т. е. максимальное количество строк).

Выход:

## [1] 288 874 1078 633 887 992 

Вы будете использовать этот индекс, чтобы перемешать титанический набор данных.

titanic <- titanic[shuffle_index, ]head(titanic)

Выход:

## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C

Шаг 2) Очистите набор данных

Структура данных показывает, что некоторые переменные имеют NA. Очистка данных должна выполняться следующим образом

  • Отбросьте переменные home.dest, cabin, name, X и ticket
  • Создать факторные переменные для pclass и выжить
  • Отбросьте NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)

Код Пояснение

  • select (-c (home.dest, cabin, name, X, ticket)): удалить ненужные переменные
  • pclass = factor (pclass, levels = c (1,2,3), labels = c ('Upper', 'Middle', 'Lower')): добавить метку к переменной pclass. 1 становится верхним, 2 - средним, а 3 - нижним.
  • фактор (выживший, уровни = c (0,1), метки = c («Нет», «Да»)): добавить метку к выживаемой переменной. 1 становится "Нет", а 2 становится "Да"
  • na.omit (): удалить наблюдения NA

Выход:

## Observations: 1,045## Variables: 8## $ pclass  Upper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived  No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex  male, male, female, female, male, male, female, male… ## $ age  61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp  0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch  0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare  32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked  S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C… 

Шаг 3) Создайте набор поездов / тестов

Перед тем, как обучить свою модель, вам необходимо выполнить два шага:

  • Создайте поезд и набор тестов: вы обучаете модель на наборе поездов и проверяете прогноз на тестовом наборе (т. Е. Невидимые данные).
  • Установите rpart.plot из консоли

Распространенной практикой является разделение данных в соотношении 80/20, 80 процентов данных служат для обучения модели, а 20 процентов - для составления прогнозов. Вам нужно создать два отдельных фрейма данных. Вы не хотите трогать набор тестов, пока не закончите построение модели. Вы можете создать функцию с именем create_train_test (), которая принимает три аргумента.

create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}

Код Пояснение

  • function (data, size = 0.8, train = TRUE): добавить аргументы в функцию
  • n_row = nrow (данные): подсчитать количество строк в наборе данных
  • total_row = size * n_row: вернуть n-ю строку для построения набора поездов
  • train_sample <- 1: total_row: выберите первую строку до n-й строки
  • if (train == TRUE) {} else {}: если условие установлено в true, вернуть набор поездов, иначе набор тестов.

Вы можете проверить свою функцию и размер.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)

Выход:

## [1] 836 8
dim(data_test)

Выход:

## [1] 209 8 

Набор данных поезда содержит 1046 строк, а тестовый набор данных - 262 строки.

Вы используете функцию prop.table () в сочетании с table (), чтобы проверить правильность процесса рандомизации.

prop.table(table(data_train$survived))

Выход:

#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Выход:

#### No Yes## 0.5789474 0.4210526

В обоих наборах данных количество выживших одинаково, около 40 процентов.

Установите rpart.plot

rpart.plot недоступен в библиотеках conda. Вы можете установить его из консоли:

install.packages("rpart.plot") 

Шаг 4) Постройте модель

Вы готовы построить модель. Синтаксис функции дерева решений Rpart:

rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree

Вы используете метод класса, потому что вы предсказываете класс.

library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106

Код Пояснение

  • rpart (): функция для соответствия модели. Аргументы следующие:
    • выжил ~ .: Формула деревьев решений
    • data = data_train: набор данных
    • method = 'class': Подобрать бинарную модель
  • rpart.plot (fit, extra = 106): построить дерево. Дополнительные функции установлены на 101 для отображения вероятности 2-го класса (полезно для двоичных ответов). Вы можете обратиться к виньетке для получения дополнительной информации о других вариантах.

Выход:

Вы начинаете с корневого узла (глубина от 0 до 3, верх графика):

  1. Вверху - общая вероятность выживания. Он показывает долю пассажиров, которые выжили в аварии. 41 процент пассажиров выжил.
  2. Этот узел спрашивает, является ли пол пассажира мужчиной. Если да, то вы спускаетесь к левому дочернему узлу корня (глубина 2). 63 процента - мужчины с вероятностью выживания 21 процент.
  3. Во втором узле вы спрашиваете, старше ли пассажир мужского пола 3,5 года. Если да, то шанс выжить составляет 19 процентов.
  4. Продолжайте так же, чтобы понять, какие особенности влияют на вероятность выживания.

Обратите внимание, что одним из многих качеств деревьев решений является то, что они не требуют подготовки данных. В частности, они не требуют масштабирования или центрирования элементов.

По умолчанию функция rpart () использует меру примеси Джини для разделения заметки. Чем выше коэффициент Джини, тем больше разных экземпляров внутри узла.

Шаг 5) Сделайте прогноз

Вы можете предсказать свой тестовый набор данных. Чтобы сделать прогноз, вы можете использовать функцию Forex (). Базовый синтаксис предсказания для дерева решений R:

predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level

Вы хотите предсказать, какие пассажиры с большей вероятностью выживут после столкновения с помощью тестового набора. Значит, среди этих 209 пассажиров вы будете знать, кто выживет или нет.

predict_unseen <-predict(fit, data_test, type = 'class')

Код Пояснение

  • предсказать (fit, data_test, type = 'class'): предсказать класс (0/1) набора тестов

Тестирование не успевшего и не успевшего пассажира.

table_mat <- table(data_test$survived, predict_unseen)table_mat

Код Пояснение

  • table (data_test $ survd, pred_unseen): создайте таблицу, чтобы подсчитать, сколько пассажиров классифицировано как выжившие и скончались, сравните с правильной классификацией дерева решений в R

Выход:

## predict_unseen## No Yes## No 106 15## Yes 30 58

Модель правильно предсказала 106 мертвых пассажиров, но классифицировала 15 выживших как мертвых. По аналогии, модель ошибочно классифицировала 30 пассажиров как выживших, когда они оказались мертвыми.

Шаг 6) Измерьте производительность

Вы можете вычислить меру точности для задачи классификации с помощью матрицы неточностей :

Матрица неточностей является лучшим выбором для оценки эффективности классификации. Общая идея состоит в том, чтобы подсчитать количество раз, когда истинные экземпляры классифицируются как ложные.

Каждая строка в матрице неточностей представляет собой фактическую цель, а каждый столбец представляет собой прогнозируемую цель. В первой строке этой матрицы учитываются мертвые пассажиры (класс False): 106 были правильно классифицированы как мертвые ( истинно отрицательные ), а оставшаяся часть была ошибочно классифицирована как выжившие ( ложноположительные ). Во второй строке учитываются выжившие, положительный класс - 58 ( истинно положительный ), а истинно отрицательный - 30.

Вы можете вычислить тест на точность из матрицы неточностей:

Это соотношение истинно положительного и истинно отрицательного к сумме матрицы. С помощью R вы можете кодировать следующим образом:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Код Пояснение

  • sum (diag (table_mat)): сумма диагонали
  • sum (table_mat): сумма матрицы.

Вы можете распечатать точность набора тестов:

print(paste('Accuracy for test', accuracy_Test))

Выход:

## [1] "Accuracy for test 0.784688995215311" 

Вы набрали 78 процентов за набор тестов. Вы можете повторить то же упражнение с набором данных для обучения.

Шаг 7) Настройте гиперпараметры

Дерево решений в R имеет различные параметры, которые контролируют аспекты соответствия. В библиотеке дерева решений rpart вы можете управлять параметрами с помощью функции rpart.control (). В следующем коде вы вводите параметры, которые вы настраиваете. Вы можете обратиться к виньетке, чтобы узнать о других параметрах.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Мы будем действовать следующим образом:

  • Построить функцию для возврата точности
  • Настройте максимальную глубину
  • Настройте минимальное количество выборок, которое должен иметь узел, прежде чем он сможет разделиться
  • Настройте минимальное количество выборок, которое должен иметь листовой узел

Вы можете написать функцию для отображения точности. Вы просто обертываете код, который использовали раньше:

  1. предсказать: предсказать_unseen <- предсказать (соответствие, тест_данных, тип = 'класс')
  2. Таблица вывода: table_mat <- table (data_test $ survival, pred_unseen)
  3. Точность вычислений: precision_Test <- sum (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}

Вы можете попробовать настроить параметры и посмотреть, сможете ли вы улучшить модель по сравнению со значением по умолчанию. Напоминаем, что вам нужно получить точность выше 0,78.

control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)

Выход:

## [1] 0.7990431 

Со следующим параметром:

minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0 

Вы получаете более высокую производительность, чем предыдущая модель. Поздравляю!

Резюме

Мы можем суммировать функции для обучения алгоритма дерева решений в R

Библиотека

Цель

функция

учебный класс

параметры

Детали

rpart

Дерево классификации поездов в R

rpart ()

учебный класс

формула, df, метод

rpart

Дерево регрессии поезда

rpart ()

анова

формула, df, метод

rpart

Постройте деревья

rpart.plot ()

подогнанная модель

основание

предсказывать

предсказывать()

учебный класс

подогнанная модель, тип

основание

предсказывать

предсказывать()

проблема

подогнанная модель, тип

основание

предсказывать

предсказывать()

вектор

подогнанная модель, тип

rpart

Параметры контроля

rpart.control ()

minsplit

Установите минимальное количество наблюдений в узле, прежде чем алгоритм выполнит разбиение

Minbucket

Задайте минимальное количество наблюдений в последней заметке, т.е. на листе

Максимальная глубина

Установите максимальную глубину любого узла окончательного дерева. Корневой узел обрабатывается глубиной 0

rpart

Модель поезда с контрольным параметром

rpart ()

формула, df, метод, контроль

Примечание. Обучите модель на обучающих данных и проверьте производительность на невидимом наборе данных, т. Е. На тестовом наборе.