Что такое логистическая регрессия?
Логистическая регрессия используется для прогнозирования класса, то есть вероятности. Логистическая регрессия может точно предсказать двоичный результат.
Представьте, что вы хотите предсказать, будет ли ссуда отклонена / принята, на основе многих признаков. Логистическая регрессия имеет вид 0/1. y = 0, если ссуда отклонена, y = 1, если она принята.
Модель логистической регрессии отличается от модели линейной регрессии двумя способами.
- Прежде всего, логистическая регрессия принимает только дихотомический (двоичный) ввод в качестве зависимой переменной (т. Е. Вектор из 0 и 1).
- Во-вторых, результат измеряется следующей функцией вероятностной связи, называемой сигмовидной из-за ее S-образной формы:
Выходные данные функции всегда находятся в диапазоне от 0 до 1. Проверьте изображение ниже.
Сигмоидальная функция возвращает значения от 0 до 1. Для задачи классификации нам нужен дискретный выход 0 или 1.
Чтобы преобразовать непрерывный поток в дискретное значение, мы можем установить границу решения на 0,5. Все значения выше этого порога классифицируются как 1
В этом руководстве вы узнаете
- Что такое логистическая регрессия?
- Как создать обобщенную модель лайнера (GLM)
- Шаг 1) Проверьте непрерывные переменные
- Шаг 2) Проверьте факторные переменные
- Шаг 3) Разработка функций
- Шаг 4) Итоговая статистика
- Шаг 5) Набор для обучения / тестирования
- Шаг 6) Постройте модель
- Шаг 7) Оцените производительность модели
Как создать обобщенную модель лайнера (GLM)
Давайте использовать набор данных для взрослых, чтобы проиллюстрировать логистическую регрессию. «Взрослый» - отличный набор данных для задачи классификации. Цель состоит в том, чтобы предсказать, превысит ли годовой долларовый доход человека 50 000 долларов США. Набор данных содержит 46 033 наблюдения и десять функций:
- age: возраст человека. Числовой
- образование: Образовательный уровень личности. Фактор.
- marital.status: Семейное положение человека. Фактор т.е. никогда не был в браке, женат-гражданский супруг,…
- Пол: Пол человека. Фактор, т.е. мужской или женский
- доход: целевая переменная. Доход выше или ниже 50К. Фактор ie> 50K, <= 50K
среди других
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Выход:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
Мы будем действовать следующим образом:
- Шаг 1. Проверьте непрерывные переменные
- Шаг 2: проверьте факторные переменные
- Шаг 3. Разработка функций
- Шаг 4: сводная статистика
- Шаг 5: набор для обучения / тестирования
- Шаг 6: Постройте модель
- Шаг 7. Оцените эффективность модели.
- Шаг 8: Улучшите модель
Ваша задача - предсказать, у какого человека будет доход выше 50К.
В этом руководстве каждый шаг будет подробно описан для выполнения анализа реального набора данных.
Шаг 1) Проверьте непрерывные переменные
На первом этапе вы можете увидеть распределение непрерывных переменных.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Код Пояснение
- Continuous <- select_if (data_adult, is.numeric): используйте функцию select_if () из библиотеки dplyr, чтобы выбрать только числовые столбцы
- сводка (непрерывная): распечатать сводную статистику
Выход:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Из приведенной выше таблицы вы можете видеть, что данные имеют совершенно разные масштабы и часы. За неделю имеют большие выбросы (например, посмотрите на последний квартиль и максимальное значение).
Вы можете справиться с этим, выполнив два шага:
- 1. Постройте график распределения часов в неделю.
- 2: Стандартизируйте непрерывные переменные
- Постройте распределение
Давайте подробнее рассмотрим распределение часов в неделю.
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Выход:
Переменная имеет множество выбросов и нечеткое распределение. Вы можете частично решить эту проблему, удалив первые 0,01 процента часов в неделю.
Базовый синтаксис квантиля:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
Мы вычисляем верхний 2-процентный процентиль
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Код Пояснение
- quantile (data_adult $ hours.per.week, .99): вычислить значение 99 процентов рабочего времени.
Выход:
## 99%## 80
98 процентов населения работает менее 80 часов в неделю.
Вы можете отбросить наблюдения выше этого порога. Вы используете фильтр из библиотеки dplyr.
data_adult_drop <-data_adult %>%filter(hours.per.weekВыход:
## [1] 45537 10
- Стандартизируйте непрерывные переменные
Вы можете стандартизировать каждый столбец для повышения производительности, потому что ваши данные не имеют одинаковый масштаб. Вы можете использовать функцию mutate_if из библиотеки dplyr. Основной синтаксис:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionВы можете стандартизировать числовые столбцы следующим образом:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Код Пояснение
- mutate_if (is.numeric, funs (scale)): условие - это только числовой столбец, а функция - масштаб
Выход:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KШаг 2) Проверьте факторные переменные
Этот шаг преследует две цели:
- Проверить уровень в каждой категориальной колонке
- Определите новые уровни
Мы разделим этот шаг на три части:
- Выберите категориальные столбцы
- Храните гистограмму каждого столбца в списке
- Распечатать графики
Мы можем выбрать факторные столбцы с помощью кода ниже:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Код Пояснение
- data.frame (select_if (data_adult, is.factor)): мы сохраняем факторные столбцы в factor в типе фрейма данных. Библиотеке ggplot2 требуется объект фрейма данных.
Выход:
## [1] 6Набор данных содержит 6 категориальных переменных.
Второй шаг более квалифицированный. Вы хотите построить гистограмму для каждого столбца в факторе фрейма данных. Процесс удобнее автоматизировать, особенно если столбцов много.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Код Пояснение
- lapply (): используйте функцию lapply (), чтобы передать функцию во все столбцы набора данных. Вы сохраняете вывод в списке
- function (x): функция будет обрабатываться для каждого x. Здесь x - столбцы
- ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): создать гистограмму для каждого элемента x. Обратите внимание: чтобы вернуть x как столбец, вам нужно включить его в get ()
Последний шаг относительно прост. Вы хотите распечатать 6 графиков.
# Print the graphgraphВыход:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Примечание. Используйте кнопку «Далее» для перехода к следующему графику.
Шаг 3) Разработка функций
Переделать образование
Из приведенного выше графика вы можете видеть, что переменное образование имеет 16 уровней. Это существенно, и на некоторых уровнях имеется относительно небольшое количество наблюдений. Если вы хотите увеличить объем информации, которую вы можете получить из этой переменной, вы можете преобразовать ее на более высокий уровень. А именно, вы создаете большие группы с одинаковым уровнем образования. Например, низкий уровень образования превратится в отсев. Высший уровень образования будет изменен на магистр.
Вот деталь:
Старый уровень
Новый уровень
Дошкольное
выбывать
10-е
Выбывать
11-е
Выбывать
12-е
Выбывать
1-4
Выбывать
5-6 места
Выбывать
7-8 место
Выбывать
9-е
Выбывать
HS-Град
HighGrad
Несколько колледжей некоторые колледжи
Сообщество
Assoc-acdm
Сообщество
Доц-вок
Сообщество
Бакалавров
Бакалавров
Мастера
Мастера
Проф-школа
Мастера
Докторская степень
кандидат наук
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Код Пояснение
- Мы используем глагол mutate из библиотеки dplyr. Мы меняем ценности образования с помощью утверждения ifelse
В приведенной ниже таблице вы создаете сводную статистику, чтобы увидеть, в среднем, сколько лет образования (z-значение) требуется, чтобы получить степень бакалавра, магистра или доктора наук.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Выход:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Изменить семейное положение
Также возможно установить более низкие уровни для семейного положения. В следующем коде вы меняете уровень следующим образом:
Старый уровень
Новый уровень
Никогда не был женат
Не женат не замужем
Женат-супруг-отсутствует
Не женат не замужем
Женат-AF-супруг
Женат
Женат-гражданский-супруг
Отдельно
Отдельно
В разводе
Вдовы
Вдова
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))Вы можете проверить количество людей в каждой группе.table(recast_data$marital.status)Выход:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Шаг 4) Итоговая статистика
Пришло время проверить статистику наших целевых переменных. На графике ниже вы подсчитываете процент людей, зарабатывающих более 50 тысяч с учетом их пола.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Выход:
Затем проверьте, влияет ли происхождение человека на его заработок.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Выход:
Количество часов работы по полу.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Выход:
Коробчатая диаграмма подтверждает, что распределение рабочего времени соответствует разным группам. На прямоугольной диаграмме оба пола не имеют однородных наблюдений.
Вы можете проверить плотность рабочего времени в неделю по типу образования. В дистрибутивах есть много разных вариантов. Вероятно, это можно объяснить типом контракта в США.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Код Пояснение
- ggplot (recast_data, aes (x = hours.per.week)): график плотности требует только одной переменной
- geom_de density (aes (color = education), alpha = 0.5): геометрический объект для управления плотностью.
Выход:
Чтобы подтвердить свои мысли, вы можете выполнить односторонний тест ANOVA:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Выход:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1Тест ANOVA подтверждает разницу в среднем между группами.
Нелинейность
Перед запуском модели вы можете увидеть, связано ли количество отработанных часов с возрастом.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Код Пояснение
- ggplot (recast_data, aes (x = age, y = hours.per.week)): установите эстетику графика.
- geom_point (aes (color = доход), size = 0,5): построить точечный график
- stat_smooth (): добавьте линию тренда со следующими аргументами:
- method = 'lm': Постройте аппроксимированное значение, если линейная регрессия
- формула = y ~ poly (x, 2): Подобрать полиномиальную регрессию
- se = TRUE: добавить стандартную ошибку
- aes (цвет = доход): разбейте модель по доходу.
Выход:
Вкратце, вы можете протестировать условия взаимодействия в модели, чтобы выявить эффект нелинейности между недельным рабочим временем и другими функциями. Важно определить, при каких условиях рабочее время отличается.
Корреляция
Следующая проверка - визуализировать корреляцию между переменными. Вы преобразовываете тип уровня фактора в числовой, чтобы можно было построить тепловую карту, содержащую коэффициент корреляции, вычисленный с помощью метода Спирмена.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Код Пояснение
- data.frame (lapply (recast_data, as.integer)): преобразовать данные в числовые.
- ggcorr () построит тепловую карту со следующими аргументами:
- method: Метод вычисления корреляции.
- nbreaks = 6: количество перерывов
- hjust = 0.8: Управляет положением имени переменной на графике
- label = TRUE: Добавить метки в центре окон.
- label_size = 3: Этикетки с размерами
- color = "grey50"): цвет метки.
Выход:
Шаг 5) Набор для обучения / тестирования
Любая задача машинного обучения с учителем требует разделения данных между набором поездов и набором тестов. Вы можете использовать «функцию», которую вы создали в других обучающих программах с учителем, для создания набора для обучения / тестирования.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Выход:
## [1] 36429 9dim(data_test)Выход:
## [1] 9108 9Шаг 6) Постройте модель
Чтобы увидеть, как работает алгоритм, вы используете пакет glm (). Обобщенная линейная модель представляет собой набор моделей. Базовый синтаксис:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")Вы готовы оценить логистическую модель, чтобы разделить уровень дохода между набором характеристик.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Код Пояснение
- формула <- доход ~.: создание модели, которая соответствует
- logit <- glm (formula, data = data_train, family = 'binomial'): сопоставление логистической модели (family = 'binomial') с данными data_train.
- Summary (logit): распечатать сводку модели.
Выход:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6Краткое изложение нашей модели раскрывает интересную информацию. Эффективность логистической регрессии оценивается с помощью определенных ключевых показателей.
- AIC (информационные критерии Akaike): эквивалент R2 в логистической регрессии. Он измеряет соответствие, когда штраф применяется к количеству параметров. Меньшие значения AIC указывают на то, что модель ближе к истине.
- Нулевое отклонение: подходит для модели только с перехватом. Степень свободы n-1. Мы можем интерпретировать это как значение хи-квадрат (подобранное значение отличается от проверки гипотезы фактического значения).
- Остаточное отклонение: модель со всеми переменными. Это также интерпретируется как проверка гипотезы хи-квадрат.
- Количество итераций скоринга Фишера: количество итераций перед сходимостью.
Вывод функции glm () сохраняется в списке. В приведенном ниже коде показаны все элементы, доступные в переменной logit, которую мы создали для оценки логистической регрессии.
# Список очень длинный, выведите только первые три элемента
lapply(logit, class)[1:3]Выход:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Каждое значение может быть извлечено с помощью знака $, за которым следует имя метрики. Например, вы сохранили модель как logit. Чтобы извлечь критерии AIC, вы используете:
logit$aicВыход:
## [1] 27086.65Шаг 7) Оцените производительность модели
Матрица путаницы
Матрица неточностей является лучшим выбором для оценки эффективности классификации по сравнению с различными метриками вы видели раньше. Общая идея состоит в том, чтобы подсчитать количество раз, когда истинные экземпляры классифицируются как ложные.
Чтобы вычислить матрицу путаницы, вам сначала нужно иметь набор прогнозов, чтобы их можно было сравнивать с фактическими целями.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matКод Пояснение
- pred (logit, data_test, type = 'response'): вычислить прогноз для набора тестов. Установите type = 'response', чтобы вычислить вероятность ответа.
- table (data_test $ доход, предсказать> 0,5): вычислить матрицу неточностей. Прогноз> 0,5 означает, что он возвращает 1, если прогнозируемые вероятности выше 0,5, в противном случае - 0.
Выход:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Каждая строка в матрице неточностей представляет собой фактическую цель, а каждый столбец представляет собой прогнозируемую цель. В первой строке этой матрицы учитывается доход ниже 50k (ложный класс): 6241 были правильно классифицированы как люди с доходом ниже 50k ( истинно отрицательный ), а оставшаяся часть была ошибочно классифицирована как выше 50k ( ложноположительный результат ). Во второй строке учитывается доход выше 50 тыс., Положительный класс - 1229 ( истинно положительный ), а истинно отрицательный - 1074.
Вы можете рассчитать точность модели , суммируя истинное положительное + истинно отрицательное по общему наблюдению.
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestКод Пояснение
- sum (diag (table_mat)): сумма диагонали
- sum (table_mat): сумма матрицы.
Выход:
## [1] 0.8277339Модель страдает одной проблемой: она завышает количество ложноотрицательных результатов. Это называется парадоксом проверки точности . Мы заявили, что точность - это отношение правильных прогнозов к общему количеству случаев. У нас может быть относительно высокая точность, но бесполезная модель. Это случается, когда есть доминирующий класс. Если вы посмотрите на матрицу путаницы, вы увидите, что большинство случаев классифицируются как истинно отрицательные. Представьте себе, что модель классифицировала все классы как отрицательные (т.е. ниже 50k). У вас будет точность 75 процентов (6718/6718 + 2257). Ваша модель работает лучше, но изо всех сил пытается отличить истинный положительный результат от истинного отрицательного.
В такой ситуации желательно иметь более сжатую метрику. Мы можем посмотреть на:
- Точность = TP / (TP + FP)
- Напомним = TP / (TP + FN)
Точность против отзыва
Точность смотрит на точность положительного прогноза. Напоминание - это соотношение положительных примеров, которые правильно обнаружены классификатором;
Вы можете создать две функции для вычисления этих двух показателей.
- Построить точность
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Код Пояснение
- mat [1,1]: возвращает первую ячейку первого столбца фрейма данных, т.е. истинное положительное значение.
- мат [1,2]; Вернуть первую ячейку второго столбца фрейма данных, т.е. ложное срабатывание.
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Код Пояснение
- mat [1,1]: возвращает первую ячейку первого столбца фрейма данных, т.е. истинное положительное значение.
- мат [2,1]; Вернуть вторую ячейку первого столбца фрейма данных, т.е. ложноотрицательный результат.
Вы можете проверить свои функции
prec <- precision(table_mat)precrec <- recall(table_mat)recВыход:
## [1] 0.712877## [2] 0.5336518Когда модель говорит, что это лицо старше 50 тысяч, она верна только в 54 процентах случаев и может требовать лиц выше 50 тысяч в 72 процентах случаев.
Вы можете создать гармоническое среднее этих двух показателей, что означает, что он придает больший вес более низким значениям.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Выход:
## [1] 0.6103799Компромисс между точностью и отзывчивостью
Невозможно иметь одновременно высокую точность и высокую отзывчивость.
Если мы увеличим точность, правильный человек будет лучше предсказан, но мы упустим многие из них (более низкий уровень запоминания). В некоторых ситуациях мы предпочитаем более высокую точность, чем отзыв. Между точностью и отзывом существует вогнутая взаимосвязь.
- Представьте, вам нужно предсказать, есть ли у пациента заболевание. Вы хотите быть максимально точными.
- Если вам нужно обнаружить потенциальных мошенников на улице с помощью распознавания лиц, было бы лучше поймать многих людей, помеченных как мошенники, даже если точность невысока. Полиция сможет освободить человека, не занимавшегося мошенничеством.
Кривая ROC
Характеристика приемник Операционный кривой является еще одним распространенным инструментом , используемым с бинарной классификацией. Это очень похоже на кривую точность / отзыв, но вместо построения графика зависимости точности от отзыва кривая ROC показывает истинно положительный коэффициент (то есть отзыв) против ложноположительного. Частота ложных срабатываний - это отношение отрицательных случаев, которые ошибочно классифицируются как положительные. Он равен единице минус истинная отрицательная ставка. Истинный отрицательный показатель также называется специфичностью . Следовательно, кривая ROC отображает чувствительность (отзыв) в зависимости от специфичности 1.
Чтобы построить кривую ROC, нам нужно установить библиотеку под названием RORC. Мы можем найти его в библиотеке conda. Вы можете ввести код:
conda install -cr r-rocr --yes
Мы можем построить ROC с помощью функций прогнозирования () и производительности ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Код Пояснение
- прогноз (прогноз, data_test $ доход): библиотеке ROCR необходимо создать объект прогнозирования для преобразования входных данных
- performance (ROCRpred, 'tpr', 'fpr'): возвращает две комбинации, которые необходимо произвести на графике. Здесь построены tpr и fpr. Суммируйте точность построения и напоминание вместе, используйте "prec", "rec".
Выход:
Шаг 8) Улучшаем модель
Вы можете попробовать добавить нелинейности к модели с помощью взаимодействия между
- возраст и часы в неделю
- пол и часы в неделю.
Для сравнения обеих моделей необходимо использовать оценочный тест.
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Выход:
## [1] 0.6109181Оценка немного выше предыдущей. Вы можете продолжить работу с данными, чтобы побить рекорд.
Резюме
Мы можем резюмировать функцию для обучения логистической регрессии в таблице ниже:
Упаковка
Цель
функция
аргумент
-
Создать набор данных для поездов / тестов
create_train_set ()
данные, размер, поезд
glm
Обучить обобщенную линейную модель
glm ()
формула, данные, семья *
glm
Обобщите модель
резюме()
подогнанная модель
основание
Сделайте прогноз
предсказывать()
подобранная модель, набор данных, тип = 'ответ'
основание
Создайте матрицу путаницы
стол()
y, предсказать ()
основание
Создать оценку точности
сумма (диаг (таблица ()) / сумма (таблица ()
РПЦЗ
Создание ROC: Шаг 1 Создание прогноза
прогноз()
предсказать (), у
РПЦЗ
Создание ROC: Шаг 2 Создание перформанса
спектакль()
предсказание (), 'tpr', 'fpr'
РПЦЗ
Создание ROC: Шаг 3 Постройте график
участок()
спектакль()
Другой тип моделей GLM :
- бином: (link = "logit")
- гауссовский: (link = "identity")
- Гамма: (ссылка = "обратная")
- inverse.gaussian: (link = "1 / mu 2")
- пуассон: (ссылка = "журнал")
- quasi: (link = "identity", variance = "constant")
- квазибиномиальный: (link = "logit")
- квазипуассон: (link = "log")