Building and evaluating a naive Bayes model on the adult dataset, with visualization (awk + hive + java + mysql + echarts)

Posted by an anonymous user, 2020-12-19 12:57

A summary of what I learned over the past few weeks, in three parts:
Linux
ECharts
data mining


I recently did a course project: build a suitable classifier for the adult dataset and evaluate it.

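I am sharing the results here. I only discovered many of the problems after I had finished, so treat this as a reference for beginners like me.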


Information about the adult dataset from the UCI website:

Data Set Characteristics: Multivariate
Number of Instances: 48842
Area: Social
Attribute Characteristics: Categorical, Integer
Number of Attributes: 14
Date Donated: 1996-05-01
Associated Tasks: Classification
Missing Values? Yes
Number of Web Hits: 913044

Listing of attributes:

>50K, <=50K.

age: continuous.
workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
fnlwgt: continuous.
education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num: continuous.
marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex: Female, Male.
capital-gain: continuous.
capital-loss: continuous.
hours-per-week: continuous.
native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.



This gives the following attributes:

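| Field | Meaning | Type |
|---|---|---|
| age | age | double |
| workclass | work class* | string |
| fnlwgt | final weight (census sampling weight) | string |
| education | education level* | string |
| education_num | years of education | double |
| marital_status | marital status* | string |
| occupation | occupation* | string |
| relationship | relationship* | string |
| race | race* | string |
| sex | sex* | string |
| capital_gain | capital gain | string |
| capital_loss | capital loss | string |
| hours_per_week | hours worked per week | double |
| native_country | native country* | string |
| income | income (class label) | string |

(* marks a categorical attribute)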


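The data also contains missing values, which are marked with "?".

The first step is data cleaning. I used awk on Linux to find the missing values and handle them.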

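#!/bin/sh
# Usage: sh <this script> adult.data <output file>
# For every "?" it prints: line number, field index, the value, the running count
# of "?" cells (num) and the running count of rows containing a "?" (num2).
# The last output line therefore holds the two totals.
infile=$1
outfile=$2

awk -F ", " 'BEGIN { num = 0; num2 = 0 }
{
    tf = 1
    for (i = 1; i <= 14; i++) {
        if ($i == "?") {
            if (tf) num2 = num2 + 1    # first "?" seen in this row
            tf = 0
            num = num + 1
            print NR "," i "," $i "," num "," num2
        }
    }
}' "$infile" > "$outfile"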


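num counts how many times "?" occurs, and num2 counts how many rows contain a "?".

The result: num = 4262, num2 = 2399.
Our row count differs from the 48842 instances listed on the UCI site. That figure most likely counts adult.data (32561 rows) and adult.test (16281 rows) together, and we only have adult.data.

The sample is large (32561 rows), and rows with missing values make up only about 7% of it (2399 / 32561 ≈ 7.4%). I therefore simply dropped the incomplete rows: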
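Here is the same counting and filtering logic sketched in Java, for readers who find it easier to follow than awk. The class name and sample rows are hypothetical. Like the awk scripts, it assumes fields separated by ", ".

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical Java counterpart of the two awk scripts: count "?" cells and
// incomplete rows, then keep only complete rows, rewritten tab-separated.
public class MissingValues {
    static boolean isMissing(String field) {
        return field.equals("?");
    }

    // Returns {number of "?" cells, number of rows with at least one "?"}
    static int[] countMissing(List<String> lines) {
        int cells = 0, rows = 0;
        for (String line : lines) {
            if (line.isEmpty()) continue;              // skip blank lines
            int inRow = 0;
            for (String field : line.split(", ")) {
                if (isMissing(field)) inRow++;
            }
            cells += inRow;
            if (inRow > 0) rows++;
        }
        return new int[]{cells, rows};
    }

    // Keeps rows without any "?" and joins their fields with tabs
    static List<String> dropIncomplete(List<String> lines) {
        List<String> kept = new ArrayList<>();
        for (String line : lines) {
            if (line.isEmpty()) continue;
            String[] fields = line.split(", ");
            if (Arrays.stream(fields).noneMatch(MissingValues::isMissing)) {
                kept.add(String.join("\t", fields));
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        // Made-up rows in the adult.data style (shortened to four fields)
        List<String> lines = Arrays.asList(
                "39, State-gov, Bachelors, <=50K",
                "54, ?, Some-college, >50K",
                "29, ?, ?, <=50K",
                "");
        int[] c = countMissing(lines);
        System.out.println("missing cells: " + c[0] + ", incomplete rows: " + c[1]);
        System.out.println(dropIncomplete(lines));
        // Share of rows dropped, using the counts reported above (2399 of 32561)
        System.out.printf("dropped share: %.1f%%%n", 100.0 * 2399 / 32561);
    }
}
```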

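#!/bin/sh
# Usage: sh <this script> adult.data <output file>
# Drops every row that contains a "?" and writes the remaining rows tab-separated.
infile=$1
outfile=$2

awk -F ", " 'BEGIN { OFS = "\t" }
NF < 15 { next }                 # skip blank or truncated lines
{
    tf = 1
    for (i = 1; i <= 14; i++) {
        if ($i == "?") tf = 0    # row has a missing value
    }
    if (tf) {
        $1 = $1                  # rebuild the record so OFS (tab) is used
        print
    }
}' "$infile" > "$outfile"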


The raw data has 14 attributes. We first ran a correlation analysis and kept the seven attributes most strongly related to income: Age, Workclass, Education, Occupation, Race, Sex, Native_country.
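The post does not show how the correlation analysis was done. For categorical attributes (with age binned first), one common option is the mutual information between each attribute and income. You would compute it for every attribute and keep the highest-scoring ones. The sketch below uses that swapped-in technique, which is not necessarily what the author ran. The class name and toy data are hypothetical.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: mutual information between one categorical attribute and income.
public class MutualInfo {
    // I(X;Y) = sum over (x, y) of p(x,y) * log2( p(x,y) / (p(x) * p(y)) ), in bits.
    // xs.get(i) is the attribute value and ys.get(i) the income label of row i.
    static double mutualInformation(List<String> xs, List<String> ys) {
        int n = xs.size();
        Map<String, Integer> cx = new HashMap<>();
        Map<String, Integer> cy = new HashMap<>();
        Map<String, Map<String, Integer>> joint = new HashMap<>();
        for (int i = 0; i < n; i++) {
            cx.merge(xs.get(i), 1, Integer::sum);
            cy.merge(ys.get(i), 1, Integer::sum);
            joint.computeIfAbsent(xs.get(i), k -> new HashMap<>()).merge(ys.get(i), 1, Integer::sum);
        }
        double mi = 0;
        for (Map.Entry<String, Map<String, Integer>> ex : joint.entrySet()) {
            double px = (double) cx.get(ex.getKey()) / n;
            for (Map.Entry<String, Integer> ey : ex.getValue().entrySet()) {
                double pxy = (double) ey.getValue() / n;
                double py = (double) cy.get(ey.getKey()) / n;
                mi += pxy * Math.log(pxy / (px * py)) / Math.log(2);
            }
        }
        return mi;
    }

    public static void main(String[] args) {
        // Made-up toy rows: one attribute that determines income, one unrelated to it
        List<String> income = Arrays.asList("<=50K", "<=50K", ">50K", ">50K");
        List<String> sex = Arrays.asList("Female", "Female", "Male", "Male");
        List<String> other = Arrays.asList("A", "B", "A", "B");
        System.out.println("sex:   " + mutualInformation(sex, income));   // fully informative
        System.out.println("other: " + mutualInformation(other, income)); // independent of income
    }
}
```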

Categorical (6): Workclass, Education, Occupation, Race, Sex, Native_country

Continuous (1): Age

Age is therefore binned, and all seven attributes are handled as categorical:

Age (17 to 90 in the data), 8 bins: <20, 20-30, 30-40, 40-50, 50-60, 60-70, 70-80, >80 (each bin includes its upper bound, e.g. age 30 falls into 20-30)

Workclass, 8 categories: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked

Education, 16 categories: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool

Occupation, 14 categories: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces

Race, 5 categories: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black

Sex, 2 categories: Female, Male

Native_country, 41 categories: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands
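The age bins above fit in a small Java function that mirrors the if/else chain in the awk split scripts. Each bin includes its upper bound. The class and method names are hypothetical.

```java
// Hypothetical helper: the age binning used by the awk split scripts.
public class AgeBins {
    // <=20 -> "<20", (20,30] -> "20-30", ..., (70,80] -> "70-80", >80 -> ">80"
    static String ageBin(int age) {
        if (age <= 20) return "<20";
        if (age > 80) return ">80";
        int upper = (age + 9) / 10 * 10;   // round up to a multiple of 10: 21..30 -> 30
        return (upper - 10) + "-" + upper;
    }

    public static void main(String[] args) {
        int[] ages = {17, 20, 21, 30, 31, 45, 80, 81, 90};
        for (int age : ages) {
            System.out.println(age + " -> " + ageBin(age));
        }
    }
}
```

Ages in this data start at 17, so the lower bound of 10 in the awk version never comes into play.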

Since almost all attributes are categorical, I chose a naive Bayes model.
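For reference, this is the standard naive Bayes decision rule that the Java code in this post implements. It uses no smoothing, and every probability is estimated by counting on the training set. In the notation, c_1 is "<=50K" and c_2 is ">50K". N is the number of training rows (zong in the code), N_c is the number of rows of class c (a or b), and N_{c,i,v} is the number of class-c rows whose i-th selected attribute equals v (the amount column produced by Hive).

```latex
\hat{c} = \arg\max_{c \in \{c_1,\, c_2\}} P(c) \prod_{i=1}^{7} P(x_i \mid c),
\qquad
P(c) = \frac{N_c}{N},
\qquad
P(x_i = v \mid c) = \frac{N_{c,i,v}}{N_c}
```

The code predicts "<=50K" when pp1 >= pp2, so ties go to "<=50K".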

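To evaluate the model, I split the cleaned data into a large training set and a small test set. The first 10000 rows are used for training, and every row after row 30001 is used for testing. The two scripts below perform the split, bin Age, and keep only the seven selected attributes plus income.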

#!/bin/sh
# Training set: the first 10000 cleaned rows (tab-separated input).
# Bins age and keeps age, workclass, education, occupation, race, sex, native_country, income.
infile=$1
outfile=$2

awk -F '\t' 'BEGIN { OFS = "\t"; i = 0 }
{
    if (i < 10000) {
        if      ($1 <= 20) $1 = "<20"
        else if ($1 <= 30) $1 = "20-30"
        else if ($1 <= 40) $1 = "30-40"
        else if ($1 <= 50) $1 = "40-50"
        else if ($1 <= 60) $1 = "50-60"
        else if ($1 <= 70) $1 = "60-70"
        else if ($1 <= 80) $1 = "70-80"
        else               $1 = ">80"
        print $1, $2, $4, $7, $9, $10, $14, $15
    }
    i = i + 1
}' "$infile" > "$outfile"


#!/bin/sh
# Test set: every cleaned row after row 30001 (tab-separated input).
# Same age binning and column selection as the training script.
infile=$1
outfile=$2

awk -F '\t' 'BEGIN { OFS = "\t"; i = 0 }
{
    if (i > 30000) {
        if      ($1 <= 20) $1 = "<20"
        else if ($1 <= 30) $1 = "20-30"
        else if ($1 <= 40) $1 = "30-40"
        else if ($1 <= 50) $1 = "40-50"
        else if ($1 <= 60) $1 = "50-60"
        else if ($1 <= 70) $1 = "60-70"
        else if ($1 <= 80) $1 = "70-80"
        else               $1 = ">80"
        print $1, $2, $4, $7, $9, $10, $14, $15
    }
    i = i + 1
}' "$infile" > "$outfile"

This produces two datasets: adult (training set) and adult_ceshi (test set; "ceshi" means "test").



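Because the dataset is fairly large, I first used Hive to compute the frequency counts the Bayes model needs. I then exported those counts to MySQL and built and evaluated the model in Java.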


First, load the two datasets (adult and adult_ceshi) into Hive, then count each attribute against income.
Part of the Hive code:

USE shixun;

CREATE TABLE shixun.age (type STRING, income STRING, amount INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t';
INSERT INTO TABLE shixun.age SELECT age, income, COUNT(income) FROM adult GROUP BY age, income;

CREATE TABLE shixun.workclass LIKE shixun.age;
INSERT INTO TABLE shixun.workclass SELECT workclass, income, COUNT(income) FROM adult GROUP BY workclass, income;

CREATE TABLE shixun.education LIKE shixun.age;
INSERT INTO TABLE shixun.education SELECT education, income, COUNT(income) FROM adult GROUP BY education, income;

CREATE TABLE shixun.occupation LIKE shixun.age;
INSERT INTO TABLE shixun.occupation SELECT occupation, income, COUNT(income) FROM adult GROUP BY occupation, income;

CREATE TABLE shixun.race LIKE shixun.age;
INSERT INTO TABLE shixun.race SELECT race, income, COUNT(income) FROM adult GROUP BY race, income;

CREATE TABLE shixun.sex LIKE shixun.age;
INSERT INTO TABLE shixun.sex SELECT sex, income, COUNT(income) FROM adult GROUP BY sex, income;

CREATE TABLE shixun.native_country LIKE shixun.age;
INSERT INTO TABLE shixun.native_country SELECT native_country, income, COUNT(income) FROM adult GROUP BY native_country, income;
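Each of these Hive statements is a plain (value, income) frequency count. To check the exported numbers on a handful of rows, a small in-memory Java equivalent can help. It assumes the 8-column tab-separated layout written by the split scripts. The class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical in-memory equivalent of
//   SELECT <attr>, income, COUNT(income) FROM adult GROUP BY <attr>, income
public class FrequencyTable {
    // Column order of the training file produced by the awk split script
    static final List<String> COLUMNS = Arrays.asList(
            "age", "workclass", "education", "occupation", "race", "sex", "native_country", "income");

    // Returns value -> (income -> count) for one attribute
    static Map<String, Map<String, Integer>> countByIncome(List<String> rows, String attribute) {
        int col = COLUMNS.indexOf(attribute);
        Map<String, Map<String, Integer>> counts = new TreeMap<>();
        for (String row : rows) {
            String[] f = row.split("\t");
            String income = f[f.length - 1];
            counts.computeIfAbsent(f[col], k -> new TreeMap<>()).merge(income, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // Made-up rows for illustration
        List<String> rows = Arrays.asList(
                "20-30\tPrivate\tBachelors\tSales\tWhite\tMale\tUnited-States\t<=50K",
                "20-30\tPrivate\tHS-grad\tSales\tWhite\tFemale\tUnited-States\t<=50K",
                "40-50\tSelf-emp-inc\tMasters\tExec-managerial\tWhite\tMale\tUnited-States\t>50K");
        System.out.println(countByIncome(rows, "age"));
        System.out.println(countByIncome(rows, "sex"));
    }
}
```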



Creating the tables in MySQL is omitted here.

Export the tables from the Hive warehouse to MySQL with sqoop:

sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table age --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/age'
sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table workclass --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/workclass'
sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table education --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/education'
sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table occupation --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/occupation'
sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table race --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/race'
sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table sex --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/sex'
sqoop export --connect jdbc:mysql://localhost:3306/shixun --username root --password strongs --table native_country --fields-terminated-by '\t' --export-dir '/user/hive/warehouse/shixun.db/native_country'






Then build and evaluate the model in Java:

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;

public class Adult {

    // The seven selected attributes; each one has its own frequency table in MySQL
    static String[] str = {"age", "workclass", "education", "occupation", "race", "sex", "native_country"};

    // The MySQL counts are loaded into memory first. There are only two classes, so for easy
    // lookup they go into two separate maps: attribute -> (attribute value -> count)
    static Map<String, Map<String, Integer>> hasha = new HashMap<>(); // income <=50K
    static Map<String, Map<String, Integer>> hashb = new HashMap<>(); // income >50K

    static double zong = 0; // total number of training rows
    static double a = 0;    // training rows with income <=50K
    static double b = 0;    // training rows with income >50K

    // Confusion matrix; "<=50K" is treated as the positive class
    static int TP = 0, TN = 0, FP = 0, FN = 0;

    // Read the frequency tables into memory
    public static void start(Connection conn) {
        try (Statement stmt = conn.createStatement()) {
            // Any single table is enough to get the class totals
            try (ResultSet set = stmt.executeQuery("select * from age")) {
                while (set.next()) {
                    int amount = set.getInt("amount");
                    zong += amount;
                    if (set.getString("income").equals("<=50K")) {
                        a += amount;
                    } else {
                        b += amount;
                    }
                }
            }
            for (String st : str) {
                Map<String, Integer> hash1 = new HashMap<>();
                Map<String, Integer> hash2 = new HashMap<>();
                try (ResultSet set = stmt.executeQuery("select * from " + st)) {
                    while (set.next()) {
                        if (set.getString("income").equals("<=50K")) {
                            hash1.put(set.getString("type"), set.getInt("amount"));
                        } else {
                            hash2.put(set.getString("type"), set.getInt("amount"));
                        }
                    }
                }
                hasha.put(st, hash1);
                hashb.put(st, hash2);
            }
        } catch (SQLException e) {
            System.err.println(e.getMessage());
        }
        System.out.println(hasha + "\n" + hashb);
        System.out.println(zong + "," + a + "," + b); // print what is now held in memory
    }

    // Predict every row of the test table and fill in the confusion matrix
    private static void Second(Connection conn) {
        try (Statement stmt = conn.createStatement()) {
            try (ResultSet set = stmt.executeQuery("select * from adult_ceshi")) {
                while (set.next()) {
                    double p1 = 1, p2 = 1;
                    for (String st : str) { // multiply P(value | class) over the attributes
                        String v = set.getString(st);
                        // a value never seen with a class counts as 0 (no smoothing)
                        p1 *= hasha.get(st).getOrDefault(v, 0) / a;
                        p2 *= hashb.get(st).getOrDefault(v, 0) / b;
                    }
                    double pp1 = p1 * (a / zong); // times the prior P(<=50K)
                    double pp2 = p2 * (b / zong); // times the prior P(>50K)
                    boolean result1 = pp1 >= pp2;                               // predicted <=50K
                    boolean result2 = set.getString("income").equals("<=50K"); // actually <=50K
                    if (result1 && result2) TP++;
                    else if (result1) FP++;
                    else if (result2) FN++;
                    else TN++;
                }
            }
            // Keep only the latest evaluation result in the bayes table
            stmt.executeUpdate("truncate table bayes");
            stmt.executeUpdate("insert into bayes values(" + TP + "," + TN + "," + FP + "," + FN + ")");
        } catch (SQLException e) {
            System.err.println(e.getMessage());
        }
        System.out.println(TP + "," + TN + "," + FP + "," + FN);
    }

    public static void main(String[] args) throws ClassNotFoundException {
        String driver = "com.mysql.jdbc.Driver";
        String url = "jdbc:mysql://localhost:3306/shixun?zeroDateTimeBehavior=convertToNull&characterEncoding=utf8";
        String user = "root";
        String password = "strongs";
        Class.forName(driver);
        try (Connection conn = DriverManager.getConnection(url, user, password)) {
            start(conn);  // load the Hive results into memory
            Second(conn); // build the model and evaluate it on the test set
        } catch (SQLException e) {
            System.err.println(e.getMessage());
        }
    }
}
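One weakness of the model code: if a test value never appears with a class in the training counts (a rare native_country, for example), sum1 or sum2 becomes 0. That class's whole product then drops to 0, whatever the other six attributes say. A common remedy, not used in the original, is Laplace (add-one) smoothing, usually computed with log probabilities so that long products cannot underflow. Below is a minimal sketch. It assumes count maps shaped like hasha/hashb (attribute -> value -> count) and the number of distinct values K_i per attribute (8, 8, 16, 14, 5, 2 and 41 from the lists above). The class and method names are hypothetical, and a smoothed model would produce a different confusion matrix from the author's.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical naive Bayes scorer with add-one (Laplace) smoothing, in log space.
public class SmoothedNaiveBayes {
    // Returns log P(c) + sum_i log P(x_i | c), where
    //   P(c)           = classTotal / total
    //   P(x_i = v | c) = (count(c, i, v) + 1) / (classTotal + K_i)
    // counts:   attribute -> (value -> count) for this class (like hasha or hashb)
    // distinct: attribute -> K_i, the number of possible values of that attribute
    // row:      attribute -> value of the row being classified
    static double logScore(Map<String, Map<String, Integer>> counts, double classTotal, double total,
                           Map<String, Integer> distinct, Map<String, String> row) {
        double score = Math.log(classTotal / total);
        for (Map.Entry<String, String> e : row.entrySet()) {
            String attr = e.getKey();
            int c = counts.get(attr).getOrDefault(e.getValue(), 0); // unseen value still gets +1
            score += Math.log((c + 1.0) / (classTotal + distinct.get(attr)));
        }
        return score;
    }

    public static void main(String[] args) {
        // Made-up counts for a single attribute "sex"
        Map<String, Integer> le50Sex = new HashMap<>();
        le50Sex.put("Female", 5);
        le50Sex.put("Male", 3);
        Map<String, Integer> gt50Sex = new HashMap<>();
        gt50Sex.put("Male", 2);                       // no Female rows in this class
        Map<String, Map<String, Integer>> le50 = new HashMap<>();
        le50.put("sex", le50Sex);
        Map<String, Map<String, Integer>> gt50 = new HashMap<>();
        gt50.put("sex", gt50Sex);
        Map<String, Integer> distinct = new HashMap<>();
        distinct.put("sex", 2);
        Map<String, String> row = new HashMap<>();
        row.put("sex", "Female");

        double s1 = logScore(le50, 8, 10, distinct, row);
        double s2 = logScore(gt50, 2, 10, distinct, row); // finite, not log(0)
        System.out.println(s1 >= s2 ? "<=50K" : ">50K"); // ties go to <=50K, as in the original
    }
}
```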




Finally, visualize the results.


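The first chart uses ECharts 2.0 for its drag-to-recalculate feature, which only 2.0 has and which is very handy. The remaining charts use ECharts 3.0.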


Backend

import java.io.IOException;
import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
// JSONObject / JSONArray come from a JSON library the post does not name
// (json-lib's net.sf.json and fastjson both provide the put()/add() calls below); import the one you use.

public class Bayes extends HttpServlet {

    @Override
    public void doGet(HttpServletRequest request, HttpServletResponse response) // override doGet
            throws ServletException, IOException {
        response.setContentType("text/html;charset=utf-8");

        String driver = "com.mysql.jdbc.Driver";
        String url = "jdbc:mysql://localhost:3306/shixun?zeroDateTimeBehavior=convertToNull&characterEncoding=utf8&useSSL=true";
        String user = "root";
        String password = "123456";
        try {
            Class.forName(driver);
        } catch (ClassNotFoundException e) {
            throw new ServletException(e);
        }

        JSONArray array = new JSONArray();
        try (Connection conn = DriverManager.getConnection(url, user, password);
             Statement stmt = conn.createStatement();
             ResultSet set = stmt.executeQuery("select * from bayes")) {
            String[] str = {"tp", "tn", "fp", "fn"};
            while (set.next()) {
                // one {name, value} object per confusion-matrix cell
                for (String st : str) {
                    JSONObject json = new JSONObject();
                    json.put("name", st);
                    json.put("value", set.getInt(st));
                    array.add(json);
                }
            }
        } catch (SQLException e) {
            System.err.println(e.getMessage());
        }

        PrintWriter out = response.getWriter();
        out.print(array);
        out.flush();
        out.close();
    }

    @Override
    public void doPost(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        doGet(request, response);
    }
}



The backend uses JSON objects as containers and stores each value as a name:value pair, so the frontend can read them easily.
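To test the frontend without a database, you can build an array of the same shape by hand: one {name, value} object per confusion-matrix cell, in the order tp, tn, fp, fn. Below is a dependency-free sketch. The class name is hypothetical. Key order inside each object may differ from what the JSON library emits, which does not matter because the frontend reads a[i].name and a[i].value.

```java
// Hypothetical helper: builds JSON shaped like the Bayes servlet's response.
public class BayesJson {
    // Produces [{"name":"tp","value":..},{"name":"tn",..},{"name":"fp",..},{"name":"fn",..}]
    static String toJson(int tp, int tn, int fp, int fn) {
        String[] names = {"tp", "tn", "fp", "fn"};
        int[] values = {tp, tn, fp, fn};
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < names.length; i++) {
            if (i > 0) sb.append(',');
            sb.append("{\"name\":\"").append(names[i]).append("\",\"value\":").append(values[i]).append('}');
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        System.out.println(toJson(6, 2, 1, 1)); // made-up counts
    }
}
```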


Frontend

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<title>Data Mining Visualization</title>
<link rel="stylesheet" href="style/css/jquery.fullPage.css" />
<link rel="stylesheet" href="style/css/base.css" />
<link rel="stylesheet" href="style/css/welcome.css" />
<!-- ECharts 2 (all-in-one build), used for the first chart -->
<script type="text/javascript" src="js/echarts-all.js"></script>
<script src="style/js/jquery-1.8.3.min.js"></script>
<script type="text/javascript" src="style/js/jquery.fullPage.min.js"></script>
<script type="text/javascript" src="style/js/jquery.bxslider.js"></script>
<script type="text/javascript" src="style/js/main.js"></script>
<script type="text/javascript">
$(document).ready(function () {
    my_section1.init();
    $('#welcome').fullpage({
        'verticalCentered': false,
        'css3': true,
        'sectionsColor': ['#6cbb52', '#e89c38', '#40a3e1'],
        'navigation': true,
        'navigationPosition': 'right'
    });
});
</script>
</head>

<body>
    <div id="welcome">
        <div class="section" id="section0">
            <div class="my_section">
                <div id="a" style="width: 100%; height: 70%; float: left;"></div>
            </div>
        </div>
        <div class="section" id="section1">
            <div class="my_section">
                <div id="b" style="width: 50%; height: 70%; float: left;"></div>
                <div id="c" style="width: 50%; height: 70%; float: left;"></div>
            </div>
        </div>
        <div class="section" id="section2">
            <div class="my_section">
                <div id="d" style="width: 50%; height: 70%; float: left;"></div>
                <div id="e" style="width: 50%; height: 70%; float: left;"></div>
            </div>
        </div>
    </div>
    <div class="fixed head">
        <div class="head_content">
            <div class="icon logo logo_blue fl" style="width: 193px; height: 46px"></div>
            <div class="fr login_box">
                <span class="">Focused on data mining</span>
            </div>
        </div>
        <div class="line"></div>
    </div>
</body>

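<script type="text/javascript">
    // Fetch [{name, value}, ...] for tp/tn/fp/fn from the Bayes servlet.
    // async: false blocks until the data arrives, so the code below can use it directly.
    var a = new Array();
    $.ajax({
        type: "post",
        url: "./Bayes",
        async: false,
        success: function (datas) {
            a = JSON.parse(datas);   // the servlet sends text/html, so datas is a string
        }
    });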

//////////////////////////////////
// Chart a: one colored pie slice per confusion-matrix cell (tp, tn, fp, fn)
var color = ["#DC143C", "#191970", "#FF6347", "#808080"];
function jo(a, b, c) {
    return {
        name: a,
        value: b,
        itemStyle: {
            normal: {color: c}
        }
    };
}
var a1 = new Array();
for (var i = 0; i < a.length; i++) {
    a1.push(jo(a[i].name, a[i].value, color[i]));
}
//////////////////////////////////
// Helper: {name, value} item for a pie series
function serise(na, va) {
    return {
        name: na,
        value: va
    };
}
/////////////////////////////////
// Chart b: correct (tp + tn) vs. wrong (fp + fn) predictions
var tu = 0;
var fl = 0;
for (var i = 0; i < a.length; i++) {
    if (a[i].name == "tp" || a[i].name == "tn")
        tu = tu + a[i].value;
    else
        fl = fl + a[i].value;
}
var b = new Array();
b.push(serise("Correct", tu));
b.push(serise("Wrong", fl));
/////////////////////////////////
// Chart c: rows predicted <=50K (tp + fp) vs. rows predicted >50K (tn + fn)
var lg = 0;
var sm = 0;
for (var i = 0; i < a.length; i++) {
    if (a[i].name == "tp" || a[i].name == "fp")
        sm = sm + a[i].value;
    else
        lg = lg + a[i].value;
}
var c = new Array();
c.push(serise(">50K", lg));
c.push(serise("<=50K", sm));
/////////////////////////////////
// Chart d (bar): [correct, wrong], the same totals as chart b
lg = 0;
sm = 0;
for (var i = 0; i < a.length; i++) {
    if (a[i].name == "tp" || a[i].name == "tn")
        sm = sm + a[i].value;
    else
        lg = lg + a[i].value;
}
var d = new Array();
d.push(sm);
d.push(lg);
/////////////////////////////////
// Chart e (bar): [predicted <=50K, predicted >50K], the same totals as chart c
lg = 0;
sm = 0;
for (var i = 0; i < a.length; i++) {
    if (a[i].name == "tp" || a[i].name == "fp")
        sm = sm + a[i].value;
    else
        lg = lg + a[i].value;
}
var e = new Array();
e.push(sm);
e.push(lg);
/////////////////////////////////

option1 = {
    title: {
        text: 'Bayes prediction results',
        x: 'center'
    },
    tooltip: {
        trigger: 'item',
        formatter: "{a} <br/>{b} : {c} ({d}%)"
    },
    legend: {
        orient: 'vertical',
        x: 'left',
        data: ['tp', 'tn', 'fp', 'fn'],
        textStyle: {color: 'auto'}
    },
    toolbox: {
        show: true,
        feature: {
            mark: {show: true},
            dataView: {show: true, readOnly: false},
            restore: {show: true},
            saveAsImage: {show: true}
        }
    },
    calculable: true,   // ECharts 2 drag-to-recalculate
    series: [
        {
            name: 'Result analysis',
            type: 'pie',
            radius: '55%',
            center: ['50%', '60%'],
            data: a1
        }
    ]
};

option2 = {
    title: {
        text: 'Bayes model accuracy',
        x: 'center'
    },
    tooltip: {
        trigger: 'item',
        formatter: "{a} <br/>{b} : {c} ({d}%)"
    },
    series: [
        {
            name: 'Result analysis',
            type: 'pie',
            radius: '55%',
            center: ['50%', '60%'],
            data: b
        }
    ]
};

option3 = {
    title: {
        text: 'Share of the two income classes',
        x: 'center'
    },
    tooltip: {
        trigger: 'item',
        formatter: "{a} <br/>{b} : {c} ({d}%)"
    },
    series: [
        {
            name: 'Result analysis',
            type: 'pie',
            radius: '55%',
            center: ['50%', '60%'],
            data: c
        }
    ]
};

option4 = {
    title: {
        text: 'Bayes model: correct vs. wrong',
        x: 'center'
    },
    tooltip: {
        trigger: 'axis'
    },
    grid: {
        left: '20%',
        right: '20%',
        width: '60%',
        bottom: '10%',
        containLabel: true
    },
    xAxis: [
        {
            type: 'category',
            data: ['Correct', 'Wrong']
        }
    ],
    yAxis: [
        {
            type: 'value'
        }
    ],
    series: [
        {
            name: 'Result analysis',
            type: 'bar',
            data: d
        }
    ]
};

option5 = {
    title: {
        text: 'Number of people per class',
        x: 'center'
    },
    tooltip: {
        trigger: 'axis'
    },
    grid: {
        left: '20%',
        right: '20%',
        width: '60%',
        bottom: '10%',
        containLabel: true
    },
    xAxis: [
        {
            type: 'category',
            data: ['<=50K', '>50K']
        }
    ],
    yAxis: [
        {
            type: 'value'
        }
    ],
    series: [
        {
            name: 'Result analysis',
            type: 'bar',
            data: e
        }
    ]
};

// The global echarts here is still ECharts 2 (echarts-all.js)
myChart1 = echarts.init(document.getElementById('a'));
myChart1.setOption(option1);
</script>
<!-- Loading ECharts 3 replaces the global echarts object for the remaining charts -->
<script type="text/javascript" src="js/echarts.min.js"></script>
<script>
myChart2 = echarts.init(document.getElementById('b'));
myChart2.setOption(option2);
myChart3 = echarts.init(document.getElementById('c'));
myChart3.setOption(option3);
myChart4 = echarts.init(document.getElementById('d'));
myChart4.setOption(option4);
myChart5 = echarts.init(document.getElementById('e'));
myChart5.setOption(option5);
</script>
</html>
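The charts only compare correct and wrong counts, which amounts to accuracy. The same four numbers in the bayes table also give precision, recall and F1 for the "<=50K" class, which the Java code treats as positive. Here is a short sketch; the class and method names are hypothetical.

```java
// Hypothetical helper: standard metrics from the confusion matrix in the bayes table.
// "<=50K" is the positive class, as in Adult.Second().
public class Metrics {
    static double accuracy(int tp, int tn, int fp, int fn) {
        return (double) (tp + tn) / (tp + tn + fp + fn);
    }

    // Of the rows predicted <=50K, the fraction that really are <=50K
    static double precision(int tp, int fp) {
        return tp + fp == 0 ? 0 : (double) tp / (tp + fp);
    }

    // Of the rows that really are <=50K, the fraction predicted as <=50K
    static double recall(int tp, int fn) {
        return tp + fn == 0 ? 0 : (double) tp / (tp + fn);
    }

    static double f1(int tp, int fp, int fn) {
        double p = precision(tp, fp);
        double r = recall(tp, fn);
        return p + r == 0 ? 0 : 2 * p * r / (p + r);
    }

    public static void main(String[] args) {
        int tp = 6, tn = 2, fp = 1, fn = 1; // made-up counts; use the four numbers Adult prints
        System.out.printf("accuracy=%.3f precision=%.3f recall=%.3f f1=%.3f%n",
                accuracy(tp, tn, fp, fn), precision(tp, fp), recall(tp, fn), f1(tp, fp, fn));
    }
}
```

Note also that charts c and e group rows by predicted class (TP + FP predicted "<=50K", TN + FN predicted ">50K"). The number of test rows that actually earn <=50K is TP + FN.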


All of this is very basic usage.

