博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
达观数据比赛 第五天任务
阅读量:4088 次
发布时间:2019-05-25

本文共 3312 字,大约阅读时间需要 11 分钟。

今天学习lightbgm算法,算是一种比较不错的轻量级的集成算法把~

【任务 3.2】LightGBM模型 时常: 2天

构建LightGBM的模型(包括:模型构建&调参&性能评估),学习理论并用Task2的特征实践

 

LightGBM有更快的训练速度和更高的效率,这是因为它是一种使用基于直方图的算法

例如,它将连续的特征值分桶(buckets)装进离散的箱子(bins),这是的训练过程中变得更快。

它与XGBoost的不同之处在于分裂节点的方式不一样。LGB避免了对整层节点分裂法,而采用了对增益最大的节点进行深入分解的方法。这样节省了大量分裂节点的资源。下图一是XGBoost的分裂方式,图二是LightGBM的分裂方式。

XGBoost分裂方式:

 

                                            

LightGBM分裂方式:

                                            

 

除此之外LightGMB还有如下的一些优势:

(1) 更低的内存占用:使用离散的箱子(bins)保存并替换连续值导致更少的内存占用。

(2) 更高的准确率(相比于其他任何提升算法):它通过leaf-wise分裂方法产生比level-wise分裂方法更复杂的树,这就是实现更高准确率的主要因素。然而,它有时候或导致过拟合,但是我们可以通过设置 max-depth参数来防止过拟合的发生。

(3) 大数据处理能力:相比于XGBoost,由于它在训练时间上的缩减,它同样能够具有处理大数据的能力。

(4) 支持并行学习。

我们都知道,XGB一共有三类参数通用参数,学习目标参数,Booster参数。对于LightGBM,有核心参数,学习控制参数,IO参数,目标参数,度量参数,网络参数,GPU参数,模型参数,这里我们常常修改得是核心参数,学习控制参数,度量参数等。具体的参数信息请参看参考文献。

最后提出代码供大家学习,欢迎大家评论及私信交流,会及时回复。这里有两种方法来使用LightGBM模型,大家可以都看看吧。Method2和我们习惯性使用sklearn中模型的方法更加贴切。

import pandas as pdfrom sklearn.model_selection import train_test_splitfrom sklearn.feature_extraction.text import TfidfVectorizerimport gensimimport timeimport pickleimport numpy as npimport csv,sysfrom sklearn import svmfrom sklearn.metrics import accuracy_score, precision_score, recall_score, f1_scorefrom sklearn.linear_model import LogisticRegressionimport lightgbm as lgbfrom sklearn.model_selection import GridSearchCV# read datadf = pd.read_csv('data/train_set.csv', nrows=5000)df.drop(columns='article', inplace=True)# # observe data# print(df['class'].value_counts(normalize=True, ascending=False))# TF-IDFvectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=3, max_df=0.9, sublinear_tf=True)vectorizer.fit(df['word_seg'])x_train = vectorizer.transform(df['word_seg'])# split training set and validation setpredictor = ['word_seg']x_train, x_validation, y_train, y_validation = train_test_split(x_train, df['class'], test_size=0.2)clf = svm.LinearSVC(C=5, dual=False)clf = LogisticRegression(C=120, dual=True)# Method 1 for creating model # # create dataset for lightgbm# lgb_train = lgb.Dataset(x_train, y_train)# # specify your configurations as a dict# params = {#     'boosting_type': 'gbdt',#     'objective': 'multiclass',#     'metric': 'multi_error',#     'num_leaf': 31,#     'learning_rate': 0.05,#     'feature_fraction': 0.9,#     'bagging_fraction': 0.8,#     'bagging_freq': 5,#     'num_class': 20,# }# print('Start training...')# # train# gbm = lgb.train(params,#                 lgb_train,#                 num_boost_round=20)# y_prediction = gbm.predict(x_validation, num_iteration=gbm.best_iteration)# result = []# for pred in y_prediction:#     result.append(int(np.argmax(pred)))# clf.fit(x_train, y_train)# y_prediction = clf.predict(x_validation)# Method 2 for creating model # # create dataset for lightgbm# lgb = lgb.sklearn.LGBMClassifier(num_leaves=30, learning_rate=0.1, n_estimators=20)# lgb.fit(x_train, y_train)# y_prediction = lgb.predict(x_validation)# label = []# for i in range(1, 20):#     label.append(i)# f1 = f1_score(y_validation, y_prediction, labels=label, average='micro')# print('The F1 Score: ' + str("%.4f" % f1))# grid search for better parameters for model lgb = lgb.sklearn.LGBMClassifier()param_grid = {    'learning_rate': [0.01, 0.1, 0.5],    'n_estimators': [30, 40]}gbm = GridSearchCV(lgb, param_grid)gbm.fit(x_train, y_train)print('网格搜索得到的最优参数是:', gbm.best_params_)

F1测试结果如下所示:

 

参考文献:

1. 掘金 香橙云子 

2.CSDN Ghost_Hzp 

 

转载地址:http://fluii.baihongyu.com/

你可能感兴趣的文章
elasticSearch安装部署
查看>>
elasticSearch基本使用
查看>>
HBase读写的几种方式(一)java篇
查看>>
Jetson Nano安装pytorch 基于torch1.6和torchvision0.7
查看>>
【Jetson-Nano】2.Tensorflow和Pytorch的安装
查看>>
ubuntu 系统下的Caffe环境搭建
查看>>
Yolov5系列AI常见数据集(1)车辆,行人,自动驾驶,人脸,烟雾
查看>>
【Jetson-Nano】2.Tensorflow object API和Pytorch的安装
查看>>
荔枝派 Nano 全志 F1C100s 编译运行 Linux ubuntu并升级gcc
查看>>
C++ STL 四种智能指针
查看>>
基于sympy的python实现三层BP神经网络算法
查看>>
玩玩机器学习1——ubuntu16.04 64位安装TensorFlow GPU+python3+cuda8.0+cudnn8.0
查看>>
CentOS7 搭建Pulsar 消息队列环境,CentOS(Linux)部署Pulsar,亲测成功,以及Python操作Pulsar实例驱动
查看>>
Git报错: OpenSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443
查看>>
Java学习笔记--带有验证码的登录案例
查看>>
数据结构与算法学习笔记(1)--数组
查看>>
jdk8和jdk11不能随意切换的问题
查看>>
2020-12-29
查看>>
2021-01-16
查看>>
AndroidStudio学习(二)-模拟小相册
查看>>