失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > 基于flink使用K-Means算法对KDD CUP99数据集进行聚类分析

基于flink使用K-Means算法对KDD CUP99数据集进行聚类分析

时间:2020-11-08 17:05:18

相关推荐

基于flink使用K-Means算法对KDD CUP99数据集进行聚类分析

1、算法简介

kmeans算法又称k均值算法,是一种聚类算法,属于无监督学习算法。

对于给定的样本集,kmeans将其中相似的样本成员分类组织到一起,最终将样本集划分成K个簇,每个簇内的样本成员相似度比较高。

2、基本功能

使用K-Means算法对KDD CUP99网络入侵检测数据集进行聚类分析 。本程序先对输入数据集进行特征转换、归一化处理,然后基于flink通过kmeans将数据集聚成两类,实现对正常点和异常点的区分,用于检测入侵异常数据。

3、环境依赖

flink-1.9.1

4、算法流程

1.随机选取K个聚类中心。(本例中两个,用于区分正常点和异常点)。

2.计算每个样本成员到聚类中心的距离,并将其分配到最近的聚类中。

3.计算每个聚类的样本均值,并将样本均值更新为新的聚类中心。

4.重复步骤2、3,直到聚类中心移动的距离小于给定阈值。

5.输出最终的聚类中心及其样本成员。

5、提交任务时可指定命令行参数

pointFile: 入侵检测数据点文件路径

outputPath: 结果输出目录

maxIterations: 算法最大迭代次数

disDiff: 迭代终止条件,即:每次迭代前后,簇中心的距离差

kNum: K值,即簇的个数。

注意:如果没有指定将使用KMeansConstant类中的默认参数

6、代码实现

多维数据点

package cn.xsy.algorithm.kmeans;import java.io.Serializable;import java.util.ArrayList;import java.util.Arrays;import java.util.List;/*** 多维数据点*/public class Point implements Serializable {//源特征值public List<String> sourceFields;//处理后的特征值public List<Double> handledFields;//特征最大值,用于归一化处理public List<Double> fieldsMaxValue = new ArrayList<Double>(42);//特征最小值,用于归一化处理public List<Double> fieldsMinValue = new ArrayList<Double>(42);//记录当前簇中数据点的个数,用于求新的簇类中心点的除法运算public Long number = 1L;public Point(){}public Point(List<String> list){sourceFields = list;}//将字符型特征转换为数值型特征public Point featureHandled(){handledFields = new ArrayList<Double>();for (int i = 0; i < sourceFields.size(); i++) {if(i == 1){//协议类型特征转换handledFields.add((double)(Arrays.asList(KMeansConstant.PROTOCOLS).indexOf(sourceFields.get(i))));}else if(i == 2){//网络服务类型特征转换List<String> sercices = new ArrayList<String>(Arrays.asList(KMeansConstant.SERVICES));int index = sercices.indexOf(sourceFields.get(i));if(index == -1){sercices.add(sourceFields.get(i));handledFields.add((double) (sercices.indexOf(sourceFields.get(i))));} else {handledFields.add((double)(index));}}else if(i == 3){//连接状态特征转换handledFields.add((double)(Arrays.asList(KMeansConstant.FLAGS).indexOf(sourceFields.get(i))));}else if(i == 41){//标识类型特征转换List<String> labels = new ArrayList<String>(Arrays.asList(KMeansConstant.LABELS));int index = labels.indexOf(sourceFields.get(i));if(index == -1){labels.add(sourceFields.get(i));handledFields.add((double) (labels.indexOf(sourceFields.get(i))));} else {handledFields.add((double)(index));}}else {handledFields.add(Double.parseDouble(sourceFields.get(i)));}}return this;}//求每一个特征的最大值和最小值public Point MaxMinValue(Point point){if(fieldsMaxValue.size() == 0){fieldsMaxValue.addAll(handledFields);}if(fieldsMinValue.size() == 0){fieldsMinValue.addAll(handledFields);}if(point.fieldsMaxValue.size() == 0){point.fieldsMaxValue.addAll(point.handledFields);}if(point.fieldsMinValue.size() == 0){point.fieldsMinValue.addAll(point.handledFields);}//求两个数据点各个特征值的最大值和最小值for(int i = 0; i< handledFields.size(); i++){if(point.fieldsMaxValue.get(i) > this.fieldsMaxValue.get(i)){fieldsMaxValue.set(i,point.fieldsMaxValue.get(i));}if(point.fieldsMinValue.get(i) < this.fieldsMinValue.get(i)){fieldsMinValue.set(i,point.fieldsMinValue.get(i));}}return this;}//归一化public Point standardHandled(Point point){for(int i = 0; i< handledFields.size(); i++){double max = point.fieldsMaxValue.get(i);double min = point.fieldsMinValue.get(i);double value = handledFields.get(i);handledFields.set(i, max == min ? min : (value - min) / (max - min));}return this;}//加法器public Point add(Point other){//特征值相加for (int i = 0; i < handledFields.size(); i++) {handledFields.set(i,handledFields.get(i) + other.handledFields.get(i));}//数据点个数相加number += other.number;return this;}//除法器public Point div(long val){for (int i = 0; i < handledFields.size(); i++) {handledFields.set(i,handledFields.get(i) / val);}return this;}//计算两点之间的欧式距离public double euclideanDistance(Point other){double sum = 0;for (int i = 0; i < handledFields.size(); i++) {sum += Math.pow((handledFields.get(i) - other.handledFields.get(i)),2);}return Math.sqrt(sum);}@Overridepublic String toString() {return "Point{" +"sourceFields=" + sourceFields +'}';}}

聚类中心

package cn.xsy.algorithm.kmeans;import java.io.Serializable;/*** 簇中心*/public class Cluster implements Serializable {//簇idpublic int id;//簇中心点public Point centre;public Cluster(int id, Point centre) {this.id = id;this.centre = centre;}public Cluster() {}@Overridepublic String toString() {return "Cluster{" +"id=" + id +", centre=" + centre +'}';}}

kmeans常量

package cn.xsy.algorithm.kmeans;public final class KMeansConstant {//入侵检测数据点文件//0,tcp,http,SF,228,896,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,23,24,0.00,0.00,0.00,0.00,1.00,0.00,0.08,255,255,1.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,normal.public static final String POINTFILE = "C:\\Users\\xsy\\Desktop\\KDD99入侵检测数据集\\Test data with corrected labels\\corrected\\corrected";//结果输出目录public static final String OUTPUTPATH = "C:\\Users\\xsy\\Desktop\\KDD99入侵检测数据集\\output";//最大迭代次数public static final int MAXITERATIONS = 10;//迭代终止条件,每次迭代前后,簇中心的距离差public static final double DISDIFF = 1.0E-13;//K值,即簇的个数public static final int KNUM = 2;/** 特征转换数据相关 **///协议类型public static final String[] PROTOCOLS = new String[]{"tcp","udp","icmp"};//目标主机的网络服务类型public static final String[] SERVICES = new String[]{"aol","auth","bgp","courier","csnet_ns","ctf","daytime","discard","domain","domain_u","echo","eco_i","ecr_i","efs","exec","finger","ftp","ftp_data","gopher","harvest","hostnames","http","http_2784","http_443","http_8001","imap4","IRC","iso_tsap","klogin","kshell","ldap","link","login","mtp","name","netbios_dgm","netbios_ns","netbios_ssn","netstat","nnsp","nntp","ntp_u","other","pm_dump","pop_2","pop_3","printer","private","red_i","remote_job","rje","shell","smtp","sql_net","ssh","sunrpc","supdup","systat","telnet","tftp_u","tim_i","time","urh_i","urp_i","uucp","uucp_path","vmnet","whois","X11","Z39_50"};//连接正常或错误的状态public static final String[] FLAGS = new String[]{"OTH","REJ","RSTO","RSTOS0","RSTR","S0","S1","S2","S3","SF","SH"};//标识类型public static final String[] LABELS = new String[]{"normal.", "buffer_overflow.", "loadmodule.", "perl.", "neptune.", "smurf.","guess_passwd.", "pod.", "teardrop.", "portsweep.", "ipsweep.", "land.", "ftp_write.","back.", "imap.", "satan.", "phf.", "nmap.", "multihop.", "warezmaster.", "warezclient.","spy.", "rootkit.","mscan.", "saint.", "apache2.", "mailbomb.", "processtable.", "udpstorm.", "httptunnel.", "ps.","sqlattack.", "xterm.", "named.", "sendmail.", "snmpgetattack.", "snmpguess.", "worm.", "xlock.", "xsnoop."};}

KMeans主程序入口

package cn.xsy.algorithm.kmeans;import org.apache.mon.JobExecutionResult;import org.apache.mon.accumulators.IntCounter;import org.apache.mon.functions.MapFunction;import org.apache.mon.functions.ReduceFunction;import org.apache.mon.functions.RichFilterFunction;import org.apache.mon.functions.RichMapFunction;import org.apache.flink.api.java.DataSet;import org.apache.flink.api.java.ExecutionEnvironment;import org.apache.flink.api.java.aggregation.Aggregations;import org.apache.flink.api.java.operators.DataSource;import org.apache.flink.api.java.operators.IterativeDataSet;import org.apache.flink.api.java.tuple.Tuple2;import org.apache.flink.api.java.tuple.Tuple3;import org.apache.flink.api.java.utils.ParameterTool;import org.apache.flink.configuration.Configuration;import org.apache.flink.core.fs.FileSystem;import java.util.*;/*** KMeans主程序入口*/public class KMeans {public static void main(String[] args) throws Exception {//解析命令行参数ParameterTool params = ParameterTool.fromArgs(args);//构建执行环境ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();//使参数在web界面可用env.getConfig().setGlobalJobParameters(params);//从提供的文件路径读取数据点DataSource<String> sourcePoints = getPointDataSet(params, env);//从数据点中随机选取簇中心点List<String> pointCollect = sourcePoints.collect();Set<String> sourceClusterSet = getSourceClusterCollection(pointCollect, params);DataSource<String> sourceCluster = env.fromCollection(sourceClusterSet);//对数据点进行字符型特征装换DataSet<Point> featurePoints = sourcePoints.map(new FeatureHandledPoint());//对数据点求每一个特征的最大值和最小值DataSet<Point> maxMinPoint = featurePoints.reduce(new MaxMinHandledPoint());//对数据点进行归一化处理DataSet<Point> points = featurePoints.map(new StandardHandledPoint()).withBroadcastSet(maxMinPoint, "maxMinPoint");//对簇中心数据进行字符型特征装换以及归一化处理DataSet<Cluster> clusters = sourceCluster.map(new HandledCluster()).withBroadcastSet(maxMinPoint, "maxMinPoint");//设置KMeans的最大迭代次数IterativeDataSet<Cluster> loop = clusters.iterate(params.getInt("maxIterations", KMeansConstant.MAXITERATIONS));//KMeans迭代过程DataSet<Cluster> newClusters = points//将每个数据点分配到最近的簇中心.map(new SelectNearestCluster()).withBroadcastSet(loop, "clusters")//每个簇内的点坐标求和以及点个数求和.groupBy(0).reduce(new ClusterAccumulator())//计算新的簇中心.map(new ClusterAverager());//迭代终止条件DataSet<Tuple2<Cluster, Cluster>> termination = loop//将每次迭代前后的簇中心连接起来.join(newClusters).where("id").equalTo("id")//根据每次迭代前后簇中心的距离差过滤簇中心.filter(new TerminationCriterion());//将新的簇中心数据反馈到下一个迭代中DataSet<Cluster> finalClusters = loop.closeWith(newClusters, termination);//将point分派到最后的簇中DataSet<Tuple2<Integer, Point>> clusterPoints = points.map(new SelectNearestCluster()).withBroadcastSet(finalClusters, "clusters");//统计每一个簇中心每一个labels的point个数// DataSet<Tuple3<Integer, String, Long>> clusterLabelsCount = clusterPoints.map(new CountClusterLabels()).groupBy(0,1).aggregate(Aggregations.SUM, 2);//统计每一个簇中心的point个数DataSet<Tuple2<Integer, Long>> clusterCount = clusterPoints.map(new CountCluster()).groupBy(0).aggregate(Aggregations.SUM, 1);// 输出结果String outputPath = params.has("outputPath") ? params.get("outputPath") : KMeansConstant.OUTPUTPATH;clusterPoints.writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE);clusterCount.print();// env.execute("KDD CUP99 KMeans");//一些统计结果以及算KMeans的Purity指数JobExecutionResult lastJobExecutionResult = env.getLastJobExecutionResult();double purity = getPurity(lastJobExecutionResult);System.out.println("purity: " + purity);}/*** 计算KMeans的Purity指数** @param lastJobExecutionResult* @return*/private static double getPurity(JobExecutionResult lastJobExecutionResult) {//数据点总数int pointCount = lastJobExecutionResult.getAccumulatorResult("pointCount");//簇中心1的正常点int cluster1Normal = lastJobExecutionResult.getAccumulatorResult("cluster1Normal");//簇中心1的异常点int cluster1Abnormal = lastJobExecutionResult.getAccumulatorResult("cluster1Abnormal");//簇中心2的正常点int cluster2Normal = lastJobExecutionResult.getAccumulatorResult("cluster2Normal");//簇中心2的异常点int cluster2Abnormal = lastJobExecutionResult.getAccumulatorResult("cluster2Abnormal");double purity;if(cluster1Abnormal > cluster2Abnormal){purity = (double) (cluster1Abnormal + cluster2Normal) / pointCount;} else if(cluster1Abnormal < cluster2Abnormal){purity = (double) (cluster2Abnormal + cluster1Normal) / pointCount;}else {if(cluster1Normal > cluster2Normal){purity = (double) (cluster2Abnormal + cluster1Normal) / pointCount;}else {purity = (double) (cluster1Abnormal + cluster2Normal) / pointCount;}}System.out.println("数据点总个数: " + pointCount);System.out.println("簇中心1正常点个数: " + cluster1Normal);System.out.println("簇中心1异常点个数: " + cluster1Abnormal);System.out.println("簇中心2正常点个数: " + cluster2Normal);System.out.println("簇中心2异常点个数: " + cluster2Abnormal);return purity;}/*** 得到输入点数据集** @param params* @param env* @return*/private static DataSource<String> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {String pointFile = params.has("pointFile") ? params.get("pointFile") : KMeansConstant.POINTFILE;DataSource<String> sourcePoints = env.readTextFile(pointFile);return sourcePoints;}/*** 从数据点中随机选取簇中心点,构建簇中心数据集** @param sourcePointList* @param params* @return*/private static Set<String> getSourceClusterCollection(List<String> sourcePointList, ParameterTool params) {int kNum = params.has("kNum") ? Integer.parseInt(params.get("kNum")) : KMeansConstant.KNUM;Set<String> clusterSet = new HashSet<String>();Random random = new Random();for (int id = 1; id <= kNum; ) {String point = sourcePointList.get(random.nextInt(sourcePointList.size()));//用于标记是否已经选择过该数据boolean flag =true;for (String cluster : clusterSet) {String[] split = cluster.split(" ");if (split[0].equals(point)) {flag = false;}}//如果随机选取的点没有被选中过,则加入到SET中if (flag) {String cluster = point + " " + id;clusterSet.add(cluster);System.out.println("簇中心" + id + ": " + cluster);id++;}}return clusterSet;}/*** 对数据点进行字符型特征装换*/public static final class FeatureHandledPoint implements MapFunction<String, Point> {public Point map(String s) throws Exception {String[] split = s.split(",");Point point = new Point(Arrays.asList(split));//字符型特征转换为数值型特征Point featurePoint = point.featureHandled();return featurePoint;}}/*** 对数据点求每个特征的最大值和最小值*/public static final class MaxMinHandledPoint implements ReduceFunction<Point> {public Point reduce(Point p1, Point p2) throws Exception {//求每一个特征的最大值和最小值return p1.MaxMinValue(p2);}}/*** 对簇中心数据进行字符型特征装换、归一化处理*/public static final class HandledCluster extends RichMapFunction<String, Cluster> {private List<Point> maxMinPoints;@Overridepublic void open(Configuration parameters) throws Exception {this.maxMinPoints = getRuntimeContext().getBroadcastVariable("maxMinPoint");}public Cluster map(String s) throws Exception {String[] fields = s.split(" ");String[] splits = fields[0].split(",");Point centre = new Point(Arrays.asList(splits));//字符型特征转换为数值型特征Point featureCentre = centre.featureHandled();//归一化Point standardCentre = featureCentre.standardHandled(maxMinPoints.get(0));return new Cluster(Integer.parseInt(fields[1]), standardCentre);}}/*** 对数据点进行归一化处理* X(norm) = (X - min) / (max - min)*/public static final class StandardHandledPoint extends RichMapFunction<Point, Point> {//point条数private IntCounter pointCount = new IntCounter();private List<Point> maxMinPoints;@Overridepublic void open(Configuration parameters) throws Exception {getRuntimeContext().addAccumulator("pointCount", pointCount);this.maxMinPoints = getRuntimeContext().getBroadcastVariable("maxMinPoint");}public Point map(Point point) throws Exception {//对每一个point进行归一化Point standardPoint = point.standardHandled(maxMinPoints.get(0));pointCount.add(1);return standardPoint;}}/*** 对每一个数据点,找到距离最近的簇中心*/public static final class SelectNearestCluster extends RichMapFunction<Point, Tuple2<Integer, Point>> {private Collection<Cluster> clusters;@Overridepublic void open(Configuration parameters) throws Exception {this.clusters = getRuntimeContext().getBroadcastVariable("clusters");}@Overridepublic Tuple2<Integer, Point> map(Point point) throws Exception {double minDistance = Double.MAX_VALUE;int closestClusterId = -1;for (Cluster cluster : clusters) {double distance = point.euclideanDistance(cluster.centre);if (distance < minDistance) {minDistance = distance;closestClusterId = cluster.id;}}return new Tuple2<Integer, Point>(closestClusterId, point);}}/*** 对每一个簇内点计数以及对簇内点的坐标进行累加*/public static final class ClusterAccumulator implements ReduceFunction<Tuple2<Integer, Point>> {public Tuple2<Integer, Point> reduce(Tuple2<Integer, Point> val1, Tuple2<Integer, Point> val2) {// 对簇内点坐标累加,然后对簇内元素个数计数return new Tuple2<Integer, Point>(val1.f0, val1.f1.add(val2.f1));}}/*** 从簇内点的个数和这些点的坐标和计算出新的簇中心*/public static final class ClusterAverager implements MapFunction<Tuple2<Integer, Point>, Cluster> {public Cluster map(Tuple2<Integer, Point> value) {// 新的簇中心id和簇中心坐标return new Cluster(value.f0, value.f1.div(value.f1.number));}}/*** 根据每次迭代前后簇中心的距离差过滤簇中心*/public static final class TerminationCriterion extends RichFilterFunction<Tuple2<Cluster, Cluster>> {public boolean filter(Tuple2<Cluster, Cluster> value) throws Exception {ParameterTool params = (ParameterTool) getRuntimeContext().getExecutionConfig().getGlobalJobParameters();double disDiff = params.has("disDiff") ? Double.parseDouble(params.get("disDiff")) : KMeansConstant.DISDIFF;double moveDistance = value.f0.centre.euclideanDistance(value.f1.centre);System.out.println("簇中心" + value.f0.id + "移动距离: " + moveDistance);return moveDistance > disDiff;}}/*** 将Tuple2<Integer, Point>转换为 Tuple3<Integer, String, Long>*/public static final class CountClusterLabels implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, String, Long>> {public Tuple3<Integer, String, Long> map(Tuple2<Integer, Point> integerPointTuple2) throws Exception {//对每一个簇,每一个LABELS的点进行计数return new Tuple3<Integer, String, Long>(integerPointTuple2.f0, integerPointTuple2.f1.sourceFields.get(41), 1L);}}/*** 将Tuple2<Integer, Point>转换为 Tuple2<Integer, Long>*/public static final class CountCluster extends RichMapFunction<Tuple2<Integer, Point>, Tuple2<Integer, Long>> {//簇中心1的正常点private IntCounter cluster1Normal = new IntCounter();//簇中心1的异常点private IntCounter cluster1Abnormal = new IntCounter();//簇中心2的正常点private IntCounter cluster2Normal = new IntCounter();//簇中心2的异常点private IntCounter cluster2Abnormal = new IntCounter();@Overridepublic void open(Configuration parameters) throws Exception {getRuntimeContext().addAccumulator("cluster1Normal", cluster1Normal);getRuntimeContext().addAccumulator("cluster1Abnormal", cluster1Abnormal);getRuntimeContext().addAccumulator("cluster2Normal", cluster2Normal);getRuntimeContext().addAccumulator("cluster2Abnormal", cluster2Abnormal);}public Tuple2<Integer, Long> map(Tuple2<Integer, Point> t2) throws Exception {if (t2.f0 == 1) {if ("normal.".equals(t2.f1.sourceFields.get(41))) {cluster1Normal.add(1);} else {cluster1Abnormal.add(1);}} else if (t2.f0 == 2) {if ("normal.".equals(t2.f1.sourceFields.get(41))) {cluster2Normal.add(1);} else {cluster2Abnormal.add(1);}}//对每一个簇内的点进行计数return new Tuple2<Integer, Long>(t2.f0, 1L);}}}

7、输出结果

簇中心1: 0,tcp,http,SF,236,314,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,4,4,0.00,0.00,0.00,0.00,1.00,0.00,0.00,255,255,1.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,normal. 1簇中心2: 0,icmp,ecr_i,SF,1032,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,120,120,0.00,0.00,0.00,0.00,1.00,0.00,0.00,255,255,1.00,0.00,1.00,0.00,0.00,0.00,0.00,0.00,smurf. 2簇中心2移动距离: 0.826594231159133簇中心1移动距离: 1.2939328884752181簇中心1移动距离: 0.04597443790175165簇中心2移动距离: 0.06020607832667806簇中心2移动距离: 0.0022798932582670174簇中心1移动距离: 0.0025190908117799595簇中心1移动距离: 2.0983487049015966E-4簇中心2移动距离: 1.841180332519695E-4簇中心2移动距离: 1.5844853826986143E-5簇中心1移动距离: 1.808769339554904E-5簇中心1移动距离: 2.1330751274520463E-16簇中心2移动距离: 1.1775693753296206E-16(1,145222)(2,165807)数据点总个数: 311029簇中心1正常点个数: 59337簇中心1异常点个数: 85885簇中心2正常点个数: 1256簇中心2异常点个数: 164551purity: 0.7198299836992692

8、参考

/asialee_bird/article/details/80491256

/hxcaifly/article/details/86496243

如果觉得《基于flink使用K-Means算法对KDD CUP99数据集进行聚类分析》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。