Flu Shot Learning: Predict H1N1 and Seasonal Flu Vaccines

I. Competition Objective

Flu Shot Learning: use the information individuals share about their backgrounds, opinions, and health behaviors to predict whether they received the H1N1 and seasonal flu vaccines.

Competition page: Flu Shot Learning: Predict H1N1 and Seasonal Flu Vaccines

II. Dataset Overview and Initial Exploration

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

1. Understanding the Dataset

First, read in the data:

# Load the training-set features
features_df = pd.read_csv("data_sets/training_set_features.csv",
                          index_col='respondent_id')

# Load the training-set labels
labels_df = pd.read_csv("data_sets/training_set_labels.csv",
                        index_col='respondent_id')

# Load the test-set features
test_df = pd.read_csv("data_sets/test_set_features.csv",
                      index_col='respondent_id')

# Join the features and labels into one DataFrame
join_df = features_df.join(labels_df)
join_df.to_csv('data_sets/join_df.csv')
  • features_df, the training features:

  An overview of the training set:

# High-level look at the data
print('features_df -- shape:', features_df.shape)  # data dimensions
features_df.head()  # first five rows
features_df -- shape: (26707, 35)
First 10 of the 35 columns:

| respondent_id | h1n1_concern | h1n1_knowledge | behavioral_antiviral_meds | behavioral_avoidance | behavioral_face_mask | behavioral_wash_hands | behavioral_large_gatherings | behavioral_outside_home | behavioral_touch_face | doctor_recc_h1n1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 |
| 1 | 3.0 | 2.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 |
| 2 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | NaN |
| 3 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 4 | 2.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 |

Last 10 columns:

| respondent_id | income_poverty | marital_status | rent_or_own | employment_status | hhs_geo_region | census_msa | household_adults | household_children | employment_industry | employment_occupation |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | Below Poverty | Not Married | Own | Not in Labor Force | oxchjgsf | Non-MSA | 0.0 | 0.0 | NaN | NaN |
| 1 | Below Poverty | Not Married | Rent | Employed | bhuqouqj | MSA, Not Principle City | 0.0 | 0.0 | pxcmvdjn | xgwztkwe |
| 2 | <= $75,000, Above Poverty | Not Married | Own | Employed | qufhixun | MSA, Not Principle City | 2.0 | 0.0 | rucpziij | xtkaffoo |
| 3 | Below Poverty | Not Married | Rent | Not in Labor Force | lrircsnp | MSA, Principle City | 0.0 | 0.0 | NaN | NaN |
| 4 | <= $75,000, Above Poverty | Married | Own | Employed | qufhixun | MSA, Not Principle City | 1.0 | 0.0 | wxleyezf | emcorrxb |

5 rows × 35 columns

Each row of features_df is one survey respondent; respondent_id is a unique random identifier, and the columns hold that respondent's feature values.

# Column-level summary
features_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 26707 entries, 0 to 26706
Data columns (total 35 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   h1n1_concern                 26615 non-null  float64
 1   h1n1_knowledge               26591 non-null  float64
 2   behavioral_antiviral_meds    26636 non-null  float64
 3   behavioral_avoidance         26499 non-null  float64
 4   behavioral_face_mask         26688 non-null  float64
 5   behavioral_wash_hands        26665 non-null  float64
 6   behavioral_large_gatherings  26620 non-null  float64
 7   behavioral_outside_home      26625 non-null  float64
 8   behavioral_touch_face        26579 non-null  float64
 9   doctor_recc_h1n1             24547 non-null  float64
 10  doctor_recc_seasonal         24547 non-null  float64
 11  chronic_med_condition        25736 non-null  float64
 12  child_under_6_months         25887 non-null  float64
 13  health_worker                25903 non-null  float64
 14  health_insurance             14433 non-null  float64
 15  opinion_h1n1_vacc_effective  26316 non-null  float64
 16  opinion_h1n1_risk            26319 non-null  float64
 17  opinion_h1n1_sick_from_vacc  26312 non-null  float64
 18  opinion_seas_vacc_effective  26245 non-null  float64
 19  opinion_seas_risk            26193 non-null  float64
 20  opinion_seas_sick_from_vacc  26170 non-null  float64
 21  age_group                    26707 non-null  object 
 22  education                    25300 non-null  object 
 23  race                         26707 non-null  object 
 24  sex                          26707 non-null  object 
 25  income_poverty               22284 non-null  object 
 26  marital_status               25299 non-null  object 
 27  rent_or_own                  24665 non-null  object 
 28  employment_status            25244 non-null  object 
 29  hhs_geo_region               26707 non-null  object 
 30  census_msa                   26707 non-null  object 
 31  household_adults             26458 non-null  float64
 32  household_children           26458 non-null  float64
 33  employment_industry          13377 non-null  object 
 34  employment_occupation        13237 non-null  object 
dtypes: float64(23), object(12)
memory usage: 8.3+ MB
  • test_df, the test features:
# High-level look at the data
print('test_df -- shape:', test_df.shape)  # data dimensions
test_df.head()  # first five rows; the index lines up with the submission format
test_df -- shape: (26708, 35)
First 10 of the 35 columns:

| respondent_id | h1n1_concern | h1n1_knowledge | behavioral_antiviral_meds | behavioral_avoidance | behavioral_face_mask | behavioral_wash_hands | behavioral_large_gatherings | behavioral_outside_home | behavioral_touch_face | doctor_recc_h1n1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 26707 | 2.0 | 2.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 |
| 26708 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 26709 | 2.0 | 2.0 | 0.0 | 0.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.0 |
| 26710 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| 26711 | 3.0 | 1.0 | 1.0 | 1.0 | 0.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.0 |

Last 10 columns:

| respondent_id | income_poverty | marital_status | rent_or_own | employment_status | hhs_geo_region | census_msa | household_adults | household_children | employment_industry | employment_occupation |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 26707 | > $75,000 | Not Married | Rent | Employed | mlyzmhmf | MSA, Not Principle City | 1.0 | 0.0 | atmlpfrs | hfxkjkmi |
| 26708 | Below Poverty | Not Married | Rent | Employed | bhuqouqj | Non-MSA | 3.0 | 0.0 | atmlpfrs | xqwwgdyp |
| 26709 | > $75,000 | Married | Own | Employed | lrircsnp | Non-MSA | 1.0 | 0.0 | nduyfdeo | pvmttkik |
| 26710 | <= $75,000, Above Poverty | Married | Own | Not in Labor Force | lrircsnp | MSA, Not Principle City | 1.0 | 0.0 | NaN | NaN |
| 26711 | <= $75,000, Above Poverty | Not Married | Own | Employed | lzgpxyit | Non-MSA | 0.0 | 1.0 | fcxhlnwr | mxkfnird |

5 rows × 35 columns

test_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 26708 entries, 26707 to 53414
Data columns (total 35 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   h1n1_concern                 26623 non-null  float64
 1   h1n1_knowledge               26586 non-null  float64
 2   behavioral_antiviral_meds    26629 non-null  float64
 3   behavioral_avoidance         26495 non-null  float64
 4   behavioral_face_mask         26689 non-null  float64
 5   behavioral_wash_hands        26668 non-null  float64
 6   behavioral_large_gatherings  26636 non-null  float64
 7   behavioral_outside_home      26626 non-null  float64
 8   behavioral_touch_face        26580 non-null  float64
 9   doctor_recc_h1n1             24548 non-null  float64
 10  doctor_recc_seasonal         24548 non-null  float64
 11  chronic_med_condition        25776 non-null  float64
 12  child_under_6_months         25895 non-null  float64
 13  health_worker                25919 non-null  float64
 14  health_insurance             14480 non-null  float64
 15  opinion_h1n1_vacc_effective  26310 non-null  float64
 16  opinion_h1n1_risk            26328 non-null  float64
 17  opinion_h1n1_sick_from_vacc  26333 non-null  float64
 18  opinion_seas_vacc_effective  26256 non-null  float64
 19  opinion_seas_risk            26209 non-null  float64
 20  opinion_seas_sick_from_vacc  26187 non-null  float64
 21  age_group                    26708 non-null  object 
 22  education                    25301 non-null  object 
 23  race                         26708 non-null  object 
 24  sex                          26708 non-null  object 
 25  income_poverty               22211 non-null  object 
 26  marital_status               25266 non-null  object 
 27  rent_or_own                  24672 non-null  object 
 28  employment_status            25237 non-null  object 
 29  hhs_geo_region               26708 non-null  object 
 30  census_msa                   26708 non-null  object 
 31  household_adults             26483 non-null  float64
 32  household_children           26483 non-null  float64
 33  employment_industry          13433 non-null  object 
 34  employment_occupation        13282 non-null  object 
dtypes: float64(23), object(12)
memory usage: 7.3+ MB

The two summaries show that the training and test sets share the same fields and field definitions; a quick programmatic check is sketched below.
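As a sanity check (my addition, not a cell from the original notebook), the schema match can be asserted directly:

# Column names and dtypes should agree between train and test
assert list(features_df.columns) == list(test_df.columns)
assert (features_df.dtypes == test_df.dtypes).all()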

  • labels_df, the training labels:
print("labels_df -- shape:",labels_df.shape)
labels_df.head()
labels_df -- shape: (26707, 2)
| respondent_id | h1n1_vaccine | seasonal_vaccine |
| --- | --- | --- |
| 0 | 0 | 0 |
| 1 | 0 | 1 |
| 2 | 0 | 0 |
| 3 | 0 | 1 |
| 4 | 0 | 0 |
  • h1n1_vaccine - whether the respondent received the H1N1 flu vaccine.
  • seasonal_vaccine - whether the respondent received the seasonal flu vaccine.

Both are binary variables: 0 = no, 1 = yes. Some respondents received neither vaccine, some received only one, and some received both.
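A quick cross-tabulation (my addition) makes it easy to confirm that all four label combinations occur:

# Joint label distribution: rows are h1n1_vaccine, columns are seasonal_vaccine
pd.crosstab(labels_df['h1n1_vaccine'], labels_df['seasonal_vaccine'])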

2. Data Quality Analysis

  Next comes some light exploration: first validate the data's integrity, then check for missing values and anomalies.

  • Integrity checks

  Integrity checking verifies that the data is internally consistent. This dataset uses respondent_id as a unique random identifier, so we confirm that it is unique within each set and that no respondent_id is shared between the training and test sets.

# Verify the training-set respondent_id values are unique
features_df.index.nunique() == features_df.shape[0]
True
# Verify that the feature and label indexes match
assert features_df.index.equals(labels_df.index), "indexes do not match"
# Verify the test-set respondent_id values are unique
test_df.index.nunique() == test_df.shape[0]
True
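The cells above cover within-set uniqueness; the train/test overlap check promised earlier could be sketched like this (my addition):

# No respondent_id should appear in both the training and test sets
assert len(features_df.index.intersection(test_df.index)) == 0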
  • Missing-value inspection

  Digging a little deeper:

# Number of missing values per column
missing_values = features_df.isnull().sum(axis=0)

# Share of missing values per column
missing_percentage = (missing_values / len(features_df)) * 100

# Dtype of each column
column_types = features_df.dtypes

# Combine the counts, percentages, and dtypes into one DataFrame
missing_info_df = pd.DataFrame({'miss_values': missing_values, 'miss_per': missing_percentage, 'data_type': column_types})

# Sort by missing share, descending
missing_info_df = missing_info_df.sort_values(by='miss_per', ascending=False)

print(missing_info_df)

                             miss_values   miss_per data_type
employment_occupation              13470  50.436215    object
employment_industry                13330  49.912008    object
health_insurance                   12274  45.957989   float64
income_poverty                      4423  16.561201    object
doctor_recc_h1n1                    2160   8.087767   float64
doctor_recc_seasonal                2160   8.087767   float64
rent_or_own                         2042   7.645936    object
employment_status                   1463   5.477965    object
marital_status                      1408   5.272026    object
education                           1407   5.268282    object
chronic_med_condition                971   3.635751   float64
child_under_6_months                 820   3.070356   float64
health_worker                        804   3.010447   float64
opinion_seas_sick_from_vacc          537   2.010709   float64
opinion_seas_risk                    514   1.924589   float64
opinion_seas_vacc_effective          462   1.729884   float64
opinion_h1n1_sick_from_vacc          395   1.479013   float64
opinion_h1n1_vacc_effective          391   1.464036   float64
opinion_h1n1_risk                    388   1.452803   float64
household_children                   249   0.932340   float64
household_adults                     249   0.932340   float64
behavioral_avoidance                 208   0.778822   float64
behavioral_touch_face                128   0.479275   float64
h1n1_knowledge                       116   0.434343   float64
h1n1_concern                          92   0.344479   float64
behavioral_large_gatherings           87   0.325757   float64
behavioral_outside_home               82   0.307036   float64
behavioral_antiviral_meds             71   0.265848   float64
behavioral_wash_hands                 42   0.157262   float64
behavioral_face_mask                  19   0.071142   float64
sex                                    0   0.000000    object
race                                   0   0.000000    object
hhs_geo_region                         0   0.000000    object
census_msa                             0   0.000000    object
age_group                              0   0.000000    object

There are plenty of missing values, especially in the two columns with the highest missing share:

  • employment_industry - the industry the respondent works in; values are short random strings.
  • employment_occupation - the respondent's occupation; values are short random strings.
employment_industry = features_df['employment_industry'].nunique()
employment_occupation = features_df['employment_occupation'].nunique()
(employment_industry,employment_occupation)
(21, 23)

These two fields were presumably masked to protect respondents' privacy, so their values cannot be interpreted directly.

# Summary statistics
features_df.describe().T.round(4)
| feature | count | mean | std | min | 25% | 50% | 75% | max |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| h1n1_concern | 26615.0 | 1.6185 | 0.9103 | 0.0 | 1.0 | 2.0 | 2.0 | 3.0 |
| h1n1_knowledge | 26591.0 | 1.2625 | 0.6181 | 0.0 | 1.0 | 1.0 | 2.0 | 2.0 |
| behavioral_antiviral_meds | 26636.0 | 0.0488 | 0.2155 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| behavioral_avoidance | 26499.0 | 0.7256 | 0.4462 | 0.0 | 0.0 | 1.0 | 1.0 | 1.0 |
| behavioral_face_mask | 26688.0 | 0.0690 | 0.2534 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| behavioral_wash_hands | 26665.0 | 0.8256 | 0.3794 | 0.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| behavioral_large_gatherings | 26620.0 | 0.3586 | 0.4796 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
| behavioral_outside_home | 26625.0 | 0.3373 | 0.4728 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
| behavioral_touch_face | 26579.0 | 0.6773 | 0.4675 | 0.0 | 0.0 | 1.0 | 1.0 | 1.0 |
| doctor_recc_h1n1 | 24547.0 | 0.2203 | 0.4145 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| doctor_recc_seasonal | 24547.0 | 0.3297 | 0.4701 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
| chronic_med_condition | 25736.0 | 0.2833 | 0.4506 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
| child_under_6_months | 25887.0 | 0.0826 | 0.2753 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| health_worker | 25903.0 | 0.1119 | 0.3153 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| health_insurance | 14433.0 | 0.8797 | 0.3253 | 0.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| opinion_h1n1_vacc_effective | 26316.0 | 3.8506 | 1.0074 | 1.0 | 3.0 | 4.0 | 5.0 | 5.0 |
| opinion_h1n1_risk | 26319.0 | 2.3426 | 1.2855 | 1.0 | 1.0 | 2.0 | 4.0 | 5.0 |
| opinion_h1n1_sick_from_vacc | 26312.0 | 2.3577 | 1.3628 | 1.0 | 1.0 | 2.0 | 4.0 | 5.0 |
| opinion_seas_vacc_effective | 26245.0 | 4.0260 | 1.0866 | 1.0 | 4.0 | 4.0 | 5.0 | 5.0 |
| opinion_seas_risk | 26193.0 | 2.7192 | 1.3851 | 1.0 | 2.0 | 2.0 | 4.0 | 5.0 |
| opinion_seas_sick_from_vacc | 26170.0 | 2.1181 | 1.3329 | 1.0 | 1.0 | 2.0 | 4.0 | 5.0 |
| household_adults | 26458.0 | 0.8865 | 0.7534 | 0.0 | 0.0 | 1.0 | 1.0 | 3.0 |
| household_children | 26458.0 | 0.5346 | 0.9282 | 0.0 | 0.0 | 0.0 | 1.0 | 3.0 |

III. Cross-Tabulating Features Against the Labels, with Visualization

Start with h1n1_concern (a categorical feature): the respondent's level of concern about H1N1 flu.

  • 0 = not at all concerned; 1 = not very concerned; 2 = somewhat concerned; 3 = very concerned.
counts = (join_df[['h1n1_concern', 'h1n1_vaccine']]
          .groupby(['h1n1_concern', 'h1n1_vaccine'])
          .size()
          .unstack('h1n1_vaccine')
          )
counts
| h1n1_concern | h1n1_vaccine = 0 | h1n1_vaccine = 1 |
| --- | --- | --- |
| 0.0 | 2849 | 447 |
| 1.0 | 6756 | 1397 |
| 2.0 | 8102 | 2473 |
| 3.0 | 3250 | 1341 |

This gives, for each level of h1n1_concern, the number of respondents who did and did not receive the H1N1 vaccine, but raw counts are hard to compare by eye.

h1n1_concern_counts = counts.sum(axis='columns')  # respondents at each concern level
h1n1_concern_counts
h1n1_concern
0.0     3296
1.0     8153
2.0    10575
3.0     4591
dtype: int64
props = counts.div(h1n1_concern_counts, axis='index')  # convert counts to proportions
props
| h1n1_concern | h1n1_vaccine = 0 | h1n1_vaccine = 1 |
| --- | --- | --- |
| 0.0 | 0.864381 | 0.135619 |
| 1.0 | 0.828652 | 0.171348 |
| 2.0 | 0.766147 | 0.233853 |
| 3.0 | 0.707907 | 0.292093 |
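The same proportions drop out of a single pd.crosstab call (an equivalent alternative, not used in the original):

# normalize='index' divides each row by its row total, reproducing props
pd.crosstab(join_df['h1n1_concern'], join_df['h1n1_vaccine'], normalize='index')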
ax = props.plot.barh(stacked=True)
ax.invert_yaxis()
ax.legend(
    loc='center left',
    bbox_to_anchor=(1.05, 0.5),
    title='h1n1_vaccine'
)
<matplotlib.legend.Legend at 0x19c036968d0>

(figure: stacked horizontal bars of H1N1 vaccination share by h1n1_concern level)

def vaccination_rate_plot(col, target, data, ax=None):
    counts = (data[[target, col]]
              .groupby([target, col])
              .size()
              .unstack(target)
              )
    group_counts = counts.sum(axis='columns')
    props = counts.div(group_counts, axis='index')

    props.plot(kind="barh", stacked=True, ax=ax)
    ax.invert_yaxis()
    ax.legend().remove()
cols_to_plot = [
    'h1n1_concern',
    'h1n1_knowledge',
    'opinion_h1n1_vacc_effective',
    'opinion_h1n1_risk',
    'opinion_h1n1_sick_from_vacc',
    'opinion_seas_vacc_effective',
    'opinion_seas_risk',
    'opinion_seas_sick_from_vacc',
    'sex',
    'age_group',
    'race',
    'census_msa',
    'education',
    'employment_status',
    'hhs_geo_region',
    'household_adults',
    'household_children',
    'marital_status',
    'income_poverty',
    'rent_or_own'
]

fig, ax = plt.subplots(
    len(cols_to_plot), 2, figsize=(9, len(cols_to_plot) * 2.5)
)
for idx, col in enumerate(cols_to_plot):
    vaccination_rate_plot(
        col, 'h1n1_vaccine', join_df, ax=ax[idx, 0]
    )
    vaccination_rate_plot(
        col, 'seasonal_vaccine', join_df, ax=ax[idx, 1]
    )

ax[0, 0].legend(
    loc='lower center', bbox_to_anchor=(0.5, 1.05), title='h1n1_vaccine'
)
ax[0, 1].legend(
    loc='lower center', bbox_to_anchor=(0.5, 1.05), title='seasonal_vaccine'
)
fig.tight_layout()

(figure: per-feature stacked bar charts of vaccination rates, h1n1_vaccine in the left column and seasonal_vaccine in the right)

IV. Data Preprocessing

1. Missing-Value Imputation Strategy

Split the features into three groups (numeric, categorical, and ordinal) and impute each group separately.

(figure: the three feature groups and the imputation strategy applied to each)

The imputation below uses sklearn.impute.SimpleImputer.

from sklearn.impute import SimpleImputer
numeric_cols = features_df.select_dtypes('number').columns  # numeric columns
categorical_cols = ['race', 'sex', 'marital_status', 'rent_or_own', 'hhs_geo_region',
                    'census_msa', 'employment_industry', 'employment_occupation']  # categorical columns
ordinal_cols = ['age_group', 'education', 'income_poverty',
                'employment_status']  # ordinal columns
# Make sure no feature was left out
assert len(numeric_cols) + len(ordinal_cols) + len(categorical_cols) == features_df.shape[1]
  • The numeric columns are almost all binary or small multi-valued variables. I originally filled them with the mode, but training results were poor; filling with the mean (even though the mean is not a value these variables can actually take) trained noticeably better.

  My intuition: suppose a 0/1 feature has ten entries, seven of them missing, and the observed values are one 1 and two 0s, a 1:2 split. Filling the seven missing entries with the mode (0) turns that into 1:9, shifting the distribution and potentially hurting training; mean filling preserves the observed mean instead. The toy example below illustrates the effect.
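A minimal sketch of that intuition (my illustration, not a cell from the original):

import numpy as np
import pandas as pd

s = pd.Series([1.0, 0.0, 0.0] + [np.nan] * 7)  # one 1, two 0s, seven missing

mode_filled = s.fillna(s.mode()[0])  # fills with 0, so the mean drops to 0.1
mean_filled = s.fillna(s.mean())     # fills with 1/3, so the mean stays at 1/3

print(mode_filled.mean(), mean_filled.mean())  # 0.1 vs 0.3333...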

# Impute the numeric features in the training and test sets
imputer = SimpleImputer(strategy='mean')  # mean imputation

numeric_cols_1 = features_df.select_dtypes('number')
numeric_cols_2 = test_df.select_dtypes('number')

filled_numeric_cols_1 = imputer.fit_transform(numeric_cols_1)  # (26707, 23)
filled_numeric_cols_2 = imputer.fit_transform(numeric_cols_2)
features_df[numeric_cols_1.columns] = filled_numeric_cols_1
test_df[numeric_cols_2.columns] = filled_numeric_cols_2
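One caveat: the second fit_transform re-estimates the column means on the test set. A more standard, leakage-free variant (my suggestion, not what the notebook actually ran) fits on the training data only and reuses those means:

# Fit the imputer on training statistics, then apply the same means to the test set
imputer = SimpleImputer(strategy='mean')
features_df[numeric_cols_1.columns] = imputer.fit_transform(numeric_cols_1)
test_df[numeric_cols_2.columns] = imputer.transform(numeric_cols_2)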
  • Categorical columns: fill with the constant 'missing'.
imputer = SimpleImputer(strategy='constant', fill_value='missing')

# Fit and transform the categorical columns
filled_categorical_cols_1 = imputer.fit_transform(features_df[categorical_cols])
filled_categorical_cols_2 = imputer.fit_transform(test_df[categorical_cols])

features_df[categorical_cols] = filled_categorical_cols_1
test_df[categorical_cols] = filled_categorical_cols_2
  • Ordinal columns: fill with the mode.
imputer = SimpleImputer(strategy='most_frequent')

# Fit and transform the ordinal columns
filled_ordinal_cols_1 = imputer.fit_transform(features_df[ordinal_cols])
filled_ordinal_cols_2 = imputer.fit_transform(test_df[ordinal_cols])

features_df[ordinal_cols] = filled_ordinal_cols_1
test_df[ordinal_cols] = filled_ordinal_cols_2
features_df[numeric_cols_1.columns].corr(numeric_only=True).style.background_gradient(cmap='coolwarm').format(precision=2)
(output: 23 × 23 correlation matrix of the numeric features, rendered as a coolwarm heatmap. The strongest off-diagonal correlations are doctor_recc_h1n1 / doctor_recc_seasonal at 0.59, behavioral_large_gatherings / behavioral_outside_home at 0.58, opinion_h1n1_risk / opinion_seas_risk at 0.56, opinion_h1n1_sick_from_vacc / opinion_seas_sick_from_vacc at 0.49, and opinion_h1n1_vacc_effective / opinion_seas_vacc_effective at 0.47.)

V. Model Understanding and Selection – CatBoost

  This is a binary classification problem, twice over: predicting whether an individual received the H1N1 flu vaccine and whether they received the seasonal flu vaccine. Logistic regression and gradient-boosted trees are the usual tools for binary classification.

  • I am not using logistic regression here. The dataset has many features, and logistic regression assumes a linear relationship; the visualizations above show that many features are not linearly related to the labels. It also has no built-in feature-selection mechanism, so regressing on every feature risks pulling in noise and overfitting.

  The common gradient-boosted tree models are XGBoost, LightGBM, and CatBoost.

  • CatBoost (categorical boosting) handles categorical features more thoroughly than the other two: it converts categorical features to numeric ones internally and can enrich the feature space. It predicts well and is robust.

  • This dataset has many categorical features, and 7 of the 10 columns with the highest missing share are categorical.

So, for a binary classification problem with this many categorical features and missing values, I model and predict with CatBoost.
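As a minimal illustration of that native categorical handling (a toy sketch, separate from the actual pipeline; the toy frame and labels are hypothetical), CatBoost accepts raw string columns directly once they are declared in cat_features, with no one-hot encoding step:

import pandas as pd
from catboost import CatBoostClassifier

toy = pd.DataFrame({
    'color': ['red', 'blue', 'red', 'green', 'blue', 'green'],  # raw strings
    'size': [1.0, 2.0, 3.0, 1.5, 2.5, 3.5],
})
y = [0, 1, 0, 1, 1, 0]

model = CatBoostClassifier(iterations=10, verbose=False, cat_features=['color'])
model.fit(toy, y)
print(model.predict_proba(toy)[:, 1])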

VI. Model Training, Validation, and Testing – CatBoost

from sklearn.impute import SimpleImputer               # missing-value handling in preprocessing
from catboost import CatBoostClassifier                # the CatBoost classifier
from catboost import Pool, cv                          # CatBoost training and cross-validation helpers
from sklearn.metrics import roc_curve, roc_auc_score   # ROC curve and AUC

1. Dataset Split

(figure: the 8:2 train/validation split of features_df and labels_df)

from sklearn.model_selection import train_test_split
# Split the features (features_df) and labels (labels_df) 8:2
X_train, X_test, y_train, y_test = train_test_split(features_df, labels_df, test_size=0.2, random_state=50)
categorical_features_indices = np.where(X_train.dtypes != float)[0]  # indices of the categorical features

2. Hyperparameter Optimization & Prediction

import optuna


# Build the train_dataset_h1 Pool
train_dataset_h1 = Pool(data=X_train,
                        label=y_train.h1n1_vaccine,
                        cat_features=categorical_features_indices)
def objective1(trial):
    # Define the hyperparameter search space
    param = {
        'iterations': trial.suggest_categorical('iterations', [100, 200, 300, 500, 1000, 1200, 1500]),
        'auto_class_weights': trial.suggest_categorical('auto_class_weights', ['Balanced', 'SqrtBalanced']),
        'learning_rate': trial.suggest_float("learning_rate", 0.001, 0.3),
        'random_strength': trial.suggest_int("random_strength", 1, 10),
        'bagging_temperature': trial.suggest_int("bagging_temperature", 0, 10),
        'max_bin': trial.suggest_categorical('max_bin', [4, 5, 6, 8, 10, 20, 30]),
        'grow_policy': trial.suggest_categorical('grow_policy', ['SymmetricTree', 'Depthwise', 'Lossguide']),
        'min_data_in_leaf': trial.suggest_int("min_data_in_leaf", 1, 10),
        'od_type': "Iter",
        'od_wait': 100,
        "depth": trial.suggest_int("depth", 2, 10),
        "l2_leaf_reg": trial.suggest_float("l2_leaf_reg", 1e-8, 100, log=True),
        'one_hot_max_size': trial.suggest_categorical('one_hot_max_size', [5, 10, 12, 100, 500, 1024]),
        'custom_metric': ['Logloss', 'AUC'],
        "loss_function": "Logloss",
    }

    # Evaluate the parameters with CatBoost's cross-validation
    scores = cv(train_dataset_h1,
                param,
                fold_count=5,
                early_stopping_rounds=10,
                plot=False, verbose=False)

    # Return the best cross-validated mean AUC as the optimization target
    return scores['test-AUC-mean'].max()

  • Hyperparameters are optimized with optuna's TPESampler:
from optuna.samplers import TPESampler
study = optuna.create_study(sampler=TPESampler(), direction="maximize")
study.optimize(objective1, n_trials=1)
[I 2023-12-28 17:35:20,258] A new study created in memory with name: no-name-60393cc1-6863-421f-9dba-06744d868143


Training on fold [0/5]

bestTest = 0.4827874453
bestIteration = 25

Training on fold [1/5]

bestTest = 0.4864917746
bestIteration = 41

Training on fold [2/5]

bestTest = 0.4929792515
bestIteration = 26

Training on fold [3/5]

bestTest = 0.5002491376
bestIteration = 38

Training on fold [4/5]


[I 2023-12-28 17:35:22,092] Trial 0 finished with value: 0.8457264063947731 and parameters: {'iterations': 1000, 'auto_class_weights': 'Balanced', 'learning_rate': 0.1465169151209682, 'random_strength': 10, 'bagging_temperature': 5, 'max_bin': 5, 'grow_policy': 'SymmetricTree', 'min_data_in_leaf': 7, 'depth': 7, 'l2_leaf_reg': 0.00016464473810433216, 'one_hot_max_size': 1024}. Best is trial 0 with value: 0.8457264063947731.


bestTest = 0.4890295168
bestIteration = 46

print(f"Number of finished trials: {len(study.trials)}")
print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}={value},")
Number of finished trials: 1
Best trial:
  Value: 0.8457264063947731
  Params: 
    iterations=1000,
    auto_class_weights=Balanced,
    learning_rate=0.1465169151209682,
    random_strength=10,
    bagging_temperature=5,
    max_bin=5,
    grow_policy=SymmetricTree,
    min_data_in_leaf=7,
    depth=7,
    l2_leaf_reg=0.00016464473810433216,
    one_hot_max_size=1024,

2.1 Build the CatBoost classifier final_model_h1 and predict

final_model_h1 = CatBoostClassifier(verbose=False,
                                    cat_features=categorical_features_indices,
                                    **trial.params)
# + final_model_h1 = CatBoostClassifier(verbose=False, cat_features=categorical_features_indices, loss_function='CrossEntropy')
# Train final_model_h1 on train_dataset_h1
final_model_h1.fit(train_dataset_h1)
<catboost.core.CatBoostClassifier at 0x19c10b39490>
predictions_h1 = final_model_h1.predict_proba(X_test)
predictions_h1 = predictions_h1[:,1].reshape(-1,1)
from sklearn.metrics import roc_curve, roc_auc_score
def plot_roc(y_true, y_score, label_name, ax):
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    ax.plot(fpr, tpr)
    ax.plot([0, 1], [0, 1], color='grey', linestyle='--')
    ax.set_ylabel('TPR')
    ax.set_xlabel('FPR')
    ax.set_title(
        f"{label_name}: AUC = {roc_auc_score(y_true, y_score):.4f}"
    )

fig, ax = plt.subplots(1, 1, figsize=(10, 8))
plot_roc(
    y_test['h1n1_vaccine'],
    predictions_h1,
    'h1n1_vaccine',
    ax=ax
)

(figure: ROC curve for h1n1_vaccine on the held-out 20%)

roc_auc_score(y_test.h1n1_vaccine, predictions_h1)
0.8624116312772383
train_dataset_se = Pool(data=X_train,
                        label=y_train.seasonal_vaccine,
                        cat_features=categorical_features_indices)
def objective2(trial):
    param = {
        'iterations': trial.suggest_categorical('iterations', [100, 200, 300, 500, 1000, 1200, 1500]),
        'learning_rate': trial.suggest_float("learning_rate", 0.001, 0.3),
        'random_strength': trial.suggest_int("random_strength", 1, 10),
        'bagging_temperature': trial.suggest_int("bagging_temperature", 0, 10),
        'max_bin': trial.suggest_categorical('max_bin', [4, 5, 6, 8, 10, 20, 30]),
        'grow_policy': trial.suggest_categorical('grow_policy', ['SymmetricTree', 'Depthwise', 'Lossguide']),
        'min_data_in_leaf': trial.suggest_int("min_data_in_leaf", 1, 10),
        'od_type': "Iter",
        'od_wait': 100,
        "depth": trial.suggest_int("max_depth", 2, 10),
        "l2_leaf_reg": trial.suggest_float("l2_leaf_reg", 1e-8, 100, log=True),
        'one_hot_max_size': trial.suggest_categorical('one_hot_max_size', [5, 10, 12, 100, 500, 1024]),
        'custom_metric': ['AUC'],
        "loss_function": "Logloss",
        'auto_class_weights': trial.suggest_categorical('auto_class_weights', ['Balanced', 'SqrtBalanced']),
    }

    scores = cv(train_dataset_se,
                param,
                fold_count=5,
                early_stopping_rounds=10,
                plot=False, verbose=False)

    return scores['test-AUC-mean'].max()
study2 = optuna.create_study(sampler=TPESampler(), direction="maximize")
study2.optimize(objective2, n_trials=1)
[I 2023-12-28 17:36:35,703] A new study created in memory with name: no-name-2f2861da-1f53-4a2f-8793-615f2df87f56


Training on fold [0/5]

bestTest = 0.4830003578
bestIteration = 44

Training on fold [1/5]

bestTest = 0.4672035603
bestIteration = 45

Training on fold [2/5]

bestTest = 0.4832904288
bestIteration = 45

Training on fold [3/5]

bestTest = 0.4842720356
bestIteration = 44

Training on fold [4/5]


[I 2023-12-28 17:36:38,014] Trial 0 finished with value: 0.8512337669520788 and parameters: {'iterations': 300, 'learning_rate': 0.2818115853666999, 'random_strength': 9, 'bagging_temperature': 3, 'max_bin': 20, 'grow_policy': 'Depthwise', 'min_data_in_leaf': 2, 'max_depth': 6, 'l2_leaf_reg': 0.005851321082700687, 'one_hot_max_size': 100, 'auto_class_weights': 'SqrtBalanced'}. Best is trial 0 with value: 0.8512337669520788.


bestTest = 0.4867891448
bestIteration = 48

print(f"Number of finished trials: {len(study2.trials)}")
print("Best trial:")
trial2 = study2.best_trial
print(f"  Value: {trial2.value}")
print("  Params: ")
for key, value in trial2.params.items():
    print(f"    {key}={value},")
Number of finished trials: 1
Best trial:
  Value: 0.8512337669520788
  Params: 
    iterations=300,
    learning_rate=0.2818115853666999,
    random_strength=9,
    bagging_temperature=3,
    max_bin=20,
    grow_policy=Depthwise,
    min_data_in_leaf=2,
    max_depth=6,
    l2_leaf_reg=0.005851321082700687,
    one_hot_max_size=100,
    auto_class_weights=SqrtBalanced,

2.2 Build the CatBoost classifier final_model_se and predict

final_model_se = CatBoostClassifier(verbose=False,
                                    cat_features=categorical_features_indices,
                                    **trial2.params)
# + final_model_se = CatBoostClassifier(verbose=False, cat_features=categorical_features_indices, loss_function='CrossEntropy')
final_model_se.fit(train_dataset_se)
<catboost.core.CatBoostClassifier at 0x19c10c8bfd0>
predictions_se = final_model_se.predict_proba(X_test)
predictions_se = predictions_se[:,1].reshape(-1,1)
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
plot_roc(
    y_test['seasonal_vaccine'],
    predictions_se,
    'seasonal_vaccine',
    ax=ax
)

(figure: ROC curve for seasonal_vaccine on the held-out 20%)

roc_auc_score(y_test.seasonal_vaccine, predictions_se)
0.8662482553101809
roc_auc_score(y_test, np.hstack((predictions_h1, predictions_se)))
0.8643299432937096
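Passing both label columns at once makes roc_auc_score macro-average the per-label AUCs, which is the competition's evaluation metric. A quick check of the equivalence (my addition):

# The macro average equals the mean of the two individual AUCs
auc_h1 = roc_auc_score(y_test.h1n1_vaccine, predictions_h1)
auc_se = roc_auc_score(y_test.seasonal_vaccine, predictions_se)
print((auc_h1 + auc_se) / 2)  # 0.86433..., matching the combined call above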

VII. Final Predictions

1. h1n1_vaccine Prediction

  • Refit on the full training set and predict h1n1_vaccine:
final_model_h1.fit(features_df, labels_df.h1n1_vaccine)
<catboost.core.CatBoostClassifier at 0x19c10b39490>
final_h1 = final_model_h1.predict_proba(test_df)
final_h1 = final_h1[:,1].reshape(-1,1)
final_h1
array([[0.12283344],
       [0.03105349],
       [0.20220141],
       ...,
       [0.1964233 ],
       [0.02112615],
       [0.54460577]])

2. seasonal_vaccine Prediction

  • Refit on the full training set and predict seasonal_vaccine:
final_model_se.fit(features_df, labels_df.seasonal_vaccine)
<catboost.core.CatBoostClassifier at 0x19c10c8bfd0>
final_se = final_model_se.predict_proba(test_df)
final_se = final_se[:,1].reshape(-1,1)
final_se
array([[0.22118574],
       [0.02546254],
       [0.7162168 ],
       ...,
       [0.20864703],
       [0.29718237],
       [0.59726648]])
submission_df = pd.read_csv("data_sets/submission_format.csv",index_col="respondent_id")
# Make sure the index values match
np.testing.assert_array_equal(test_df.index.values, submission_df.index.values)

submission_df["h1n1_vaccine"] = final_h1
submission_df["seasonal_vaccine"] = final_se

submission_df.head()
| respondent_id | h1n1_vaccine | seasonal_vaccine |
| --- | --- | --- |
| 26707 | 0.122833 | 0.221186 |
| 26708 | 0.031053 | 0.025463 |
| 26709 | 0.202201 | 0.716217 |
| 26710 | 0.628900 | 0.886977 |
| 26711 | 0.374536 | 0.486767 |
date = pd.Timestamp.now().strftime(format='%Y-%m-%d_%H-%M_')
submission_df.to_csv(f'predictions/{date}submssion_catboost_optuna.csv', index=True)

VIII. Retrospective

1. Model Performance Over Time

Over roughly a month I made more than 30 submissions; below I summarize the six that changed the score most noticeably.

my_score = pd.DataFrame({'score':[0.8185,0.8600,0.8607,0.8609,0.7878,0.8627]})
import matplotlib.pyplot as plt
import pandas as pd
plt.figure(figsize=(10, 6))
plt.plot(my_score['score'], marker='o', linestyle='-', color='b', label='Model Performance')
plt.title('Model Performance Over Time')
plt.xlabel('Iterations')
plt.ylabel('Score')
plt.grid(True)
plt.legend()
plt.show()

(figure: submission score across the six milestone iterations)

  • score = 0.8185: following the reference baseline blog, I filled the numeric columns with the median and predicted with logistic regression.
  • score = 0.8600: switched to the gradient-boosted tree model CatBoost, tuned with optuna (100+ trials, over an hour of search); markedly better than logistic regression.
  • score = 0.8607: on top of that, changed the imputation strategy, using SimpleImputer to fill the numeric, categorical, and ordinal columns with the mean, a constant, and the mode respectively.
  • score = 0.8609: changed the train/test split from 7:3 to 8:2 and the cv fold count from 7 to 5.
  • score = 0.7878: converted the numeric multi-valued variables most strongly related to the labels, such as h1n1_concern and opinion_seas_risk, into categorical features for training. Local ROC AUC still looked fine, but the submitted score dropped sharply, possibly from overfitting driven by those highly correlated features; I have not pinned down the cause.
  • score = 0.8627: dropped the optuna-tuned hyperparameters entirely and used CatBoostClassifier(verbose=False, cat_features=categorical_features_indices, loss_function='CrossEntropy') directly, which worked well (a condensed sketch follows this list).
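For reference, a condensed sketch of that final, best-scoring configuration, assembled by me from the cells above (assuming both targets use the same settings):

from catboost import CatBoostClassifier

# One untuned CatBoost model per target, CrossEntropy loss
params = dict(verbose=False,
              cat_features=categorical_features_indices,
              loss_function='CrossEntropy')

model_h1 = CatBoostClassifier(**params).fit(features_df, labels_df.h1n1_vaccine)
model_se = CatBoostClassifier(**params).fit(features_df, labels_df.seasonal_vaccine)

submission_df["h1n1_vaccine"] = model_h1.predict_proba(test_df)[:, 1]
submission_df["seasonal_vaccine"] = model_se.predict_proba(test_df)[:, 1]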

2. Final Ranking


My final rank: 158

  • Among all registered participants: 158/5488, top 2.88%
  • Among participants with a scored submission: 158/1773, top 8.91%

(figure: leaderboard screenshot of the final rank, 158)

import seaborn as sns
import matplotlib.pyplot as plt
rank_df = pd.read_csv('rank.csv')
plt.figure(figsize=(10, 6))
sns.scatterplot(x=rank_df.index, y=rank_df['Score'], color='coral', alpha=0.7)

target_rank = 158
plt.scatter(x=target_rank, y=rank_df['Score'].iloc[target_rank], color='red', s=100, label='Rank 158')

plt.title('Final rank:158')
plt.xlabel('Index')
plt.ylabel('Score')
plt.legend()
plt.show()

(figure: scatter plot of all leaderboard scores with rank 158 highlighted in red)


  All in all, a decent result. Deeper exploratory analysis, feature engineering, model tuning, and more optuna trials could all push the score higher; this model still has plenty of headroom.