Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Finding the most frequent attributes set in census dataset

I'm working with a dataset of 12 variables and want to write a function with two inputs (numberOfAttributes, supportThreshold).

For example, with input (4, .6), I'd like to retrieve every combination of 4 attribute values that occurs in more than 60% of the rows.

Here's my code:

def attributesSet(numberOfAttributes, supportThreshold):
    import csv
    import pandas as pd
    import itertools
    import math

    names = ['age','sex','education','country','race','status','workclass','occupation',
             'hours-per-week','income','capital-gain','capital-loss']
    combinations = []
    final = []
    for comb in itertools.combinations(names,numberOfAttributes):
        combinations.append(list(comb))
    c = pd.read_csv('census.csv')
    c.columns= names
    total = len(c.index)
    required = supportThreshold*total

    for i in combinations:
        g = c.groupby(i).size().sort_values(ascending=False)
        groups = g[g>required].index
        satisfied = list(groups)
        for j in satisfied:
            row = ''
            for t in j:
                row = row + t
                if j.index(t) != len(j)-1:
                    row = row + ','
            final.append(''+row)
    return final

My code works until I set numberOfAttributes to 1, in which case the output has a comma between every character. Does anyone know how I can fix this?



1 Answer

The problem in your code is here:

            row = ''
            for t in j:
                row = row + t
                if j.index(t) != len(j)-1:
                    row = row + ','

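because j is a string when numberOfAttributes is 1 (grouping by a single column produces a plain index, so the loop iterates over the characters of j), while it is a tuple with numberOfAttributes items if numberOfAttributes is greater than 1.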
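
The effect can be reproduced without pandas. The snippet below is a minimal sketch that wraps the loop from the question in a hypothetical helper, `join_like_question` (the name is not part of the original code), and feeds it first a tuple and then a plain string:

```python
def join_like_question(j):
    # Same loop as in the question, wrapped in a hypothetical helper.
    row = ''
    for t in j:
        row = row + t
        if j.index(t) != len(j) - 1:
            row = row + ','
    return row

# A tuple is iterated item by item, so the items are joined as intended.
print(join_like_question(('race=White', 'income=Small')))  # race=White,income=Small

# A string is iterated character by character, so every character gets a comma.
print(join_like_question('abc'))  # a,b,c
```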

So, you can fix your code by changing the way row is computed, based on the type of j:

            if isinstance(j, str):
                row = j
            else:
                row = ''
                for t in j:
                    row = row + t
                    if j.index(t) != len(j)-1:
                        row = row + ','
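
Checking the position with `j.index(t)` is also fragile: `index` returns the first match, so a repeated value would put the commas in the wrong place. As a hedged alternative, the sketch below normalises the key to a tuple and lets `str.join` place the separators. `format_key` is a hypothetical name, and the `str()` conversion is an assumption added so that non-string values (for example a numeric age) would not raise an error.

```python
def format_key(key):
    # Hypothetical helper: a single-column group key arrives as a bare value,
    # a multi-column key as a tuple; treat both the same way.
    if not isinstance(key, tuple):
        key = (key,)
    # str.join only puts commas between items, never after the last one.
    return ','.join(str(value) for value in key)

print(format_key('sex=Male'))                      # sex=Male
print(format_key(('race=White', 'income=Small')))  # race=White,income=Small
print(format_key(('a', 'a')))                      # a,a
```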

However, you can significantly simplify your code, making it easier to read:

import pandas as pd
import itertools

def get_attributes_set(filepath, n_attributes, support_threshold):
    df = pd.read_csv(filepath)
    # minimum number of rows a group must exceed to count as frequent
    required = support_threshold * len(df.index)
    final = []
    for combo in itertools.combinations(df.columns, n_attributes):
        g = df.groupby(list(combo)).size().sort_values(ascending=False)
        # keep every group above the threshold, not only the largest one
        for key in g[g > required].index:
            # one column gives plain strings, several columns give tuples
            final.append(key if isinstance(key, str) else ','.join(key))
    return final

Testing the previous code with the following lines:

print(get_attributes_set('census.csv', 1, .6))
print(get_attributes_set('census.csv', 4, .6))

you get:

['sex=Male', 'native-country=United-States', 'race=White', 'workclass=Private', 'income=Small', 'capital-gain=None', 'capital-loss=None']
['native-country=United-States,race=White,capital-gain=None,capital-loss=None', 'native-country=United-States,income=Small,capital-gain=None,capital-loss=None']
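
To try the logic without census.csv, the same support count can be sketched with the standard library only. This swaps the pandas `groupby` for `collections.Counter`; the function name `frequent_value_sets` and the five toy rows are invented for illustration and are not taken from the census data.

```python
from collections import Counter
from itertools import combinations

def frequent_value_sets(rows, n_attributes, support_threshold):
    # rows: list of dicts that all share the same keys (hypothetical toy data)
    required = support_threshold * len(rows)
    columns = list(rows[0])
    result = []
    for combo in combinations(columns, n_attributes):
        # count how often each tuple of values occurs for this column combination
        counts = Counter(tuple(row[col] for col in combo) for row in rows)
        for values, count in counts.most_common():
            if count > required:
                result.append(','.join(values))
    return result

rows = [
    {'sex': 'sex=Male', 'race': 'race=White', 'income': 'income=Small'},
    {'sex': 'sex=Male', 'race': 'race=White', 'income': 'income=Large'},
    {'sex': 'sex=Male', 'race': 'race=Black', 'income': 'income=Small'},
    {'sex': 'sex=Female', 'race': 'race=White', 'income': 'income=Small'},
    {'sex': 'sex=Male', 'race': 'race=White', 'income': 'income=Small'},
]

print(frequent_value_sets(rows, 1, 0.6))
# ['sex=Male', 'race=White', 'income=Small']
print(frequent_value_sets(rows, 2, 0.5))
# ['sex=Male,race=White', 'sex=Male,income=Small', 'race=White,income=Small']
```

Because the keys are always built as tuples here, the single-attribute case needs no special handling.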
