
python - Linear regression on the Car Dekho dataset, validation loss lower than training loss

I was performing a simple regression on the Car Dekho dataset (Version 3, you can find it here), and I found that the validation loss is always lower than the training loss.

I did some basic pre-processing: I extracted the numerical values from some columns, dropped one column (torque) and one outlier, normalized the features, and dummified the categorical ones. Then I ran the regression using Keras.

This is my code:

import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
from tensorflow.keras import layers

pd.options.display.max_columns = None
pd.options.display.width = None


def pre_process(df_data):
    ## check NaNs and drop rows if any
    print(df_data.isnull().sum())
    df_data.dropna(inplace=True)

    ## drop weird outlier, turns out it has 1 km_driven
    df_data.drop([7913], inplace=True)

    ## scale the target down to thousands
    df_data['selling_price'] = df_data['selling_price'] / 1000

    ## take only the first word in these columns (removing car models and units of measure)
    df_data['name'] = df_data['name'].map(lambda x: x.split(' ')[0])
    df_data['owner'] = df_data['owner'].map(lambda x: x.split(' ')[0])
    df_data['mileage'] = (df_data['mileage'].astype(str).apply(lambda x: x.split(' ')[0])).astype(float)
    df_data['engine'] = (df_data['engine'].astype(str).apply(lambda x: x.split(' ')[0])).astype(float)
    df_data['max_power'] = (df_data['max_power'].astype(str).apply(lambda x: x.split(' ')[0])).astype(float)

    ## drop the torque column entirely
    df_data.drop(['torque'], axis=1, inplace=True)


    ## dummify categorical features
    df_data = pd.get_dummies(df_data, drop_first=True)

    ## data normalization (min-max)
    print('Data normalization')
    df_data = normalize(df_data)
    return df_data

def normalize(df):
    result = df.copy()
    for feature_name in df.columns:
        ## leave the target unscaled
        if feature_name == 'selling_price':
            continue
        # print(f'Normalizing {feature_name}')
        result[feature_name] = (df[feature_name] - df[feature_name].min()) / (df[feature_name].max() - df[feature_name].min())
        ## a constant column yields 0/0 = NaN here; drop it
        if result[feature_name].isnull().values.any():
            result.drop([feature_name], axis=1, inplace=True)
            print(f'Something wrong in {feature_name}, dropped.')
            print(f'now shape is {len(result)}, {len(result.columns)}')
    print(f'Returning {len(result)}, {len(result.columns)}')
    return result



def build_model():
    ## a single Dense(1) layer: plain linear regression
    model = keras.Sequential([
        layers.Dense(1)
    ])
    ## deeper alternative, kept for reference:
    # model = keras.Sequential([
    #     layers.Dense(64, activation='relu', input_shape=[len(train_dataset.keys())]),
    #     layers.Dense(64, activation='relu'),
    #     layers.Dense(1)
    # ])

    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)

    model.compile(loss='mse',
                  optimizer=optimizer,
                  metrics=['mae', 'mse'])
    return model



df_data = pd.read_csv('sample_data/car_details_v3.csv')

## data pre-processing
df_data = pre_process(df_data)
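
## Note (added for clarity): validation_split in fit() below slices off the
## *last* 40% of the rows as the validation set, without shuffling first, so
## an ordered CSV can give the two splits systematically different
## distributions. A minimal sketch of shuffling up front with plain pandas
## (left commented out; it is not part of the original run):
# df_data = df_data.sample(frac=1.0, random_state=42).reset_index(drop=True)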

X = df_data.copy()
Y = X.pop('selling_price')


model = build_model()
history = model.fit(X, Y, validation_split = 0.4, epochs=100, batch_size=500)


plt.plot(history.history['mse'])
plt.plot(history.history['val_mse'])
plt.title('model mse')
plt.ylabel('mse')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

And this is the training output:

Epoch 1/100
10/10 [==============================] - 1s 23ms/step - loss: 1153458.8352 - mae: 674.3291 - mse: 1153458.8352 - val_loss: 971128.0000 - val_mae: 623.6071 - val_mse: 971128.0000
Epoch 2/100
10/10 [==============================] - 0s 6ms/step - loss: 1215337.1818 - mae: 674.7663 - mse: 1215337.1818 - val_loss: 971033.8125 - val_mae: 623.5282 - val_mse: 971033.8125
Epoch 3/100
10/10 [==============================] - 0s 7ms/step - loss: 1184839.4545 - mae: 679.9471 - mse: 1184839.4545 - val_loss: 970946.9375 - val_mae: 623.4553 - val_mse: 970946.9375
Epoch 4/100
10/10 [==============================] - 0s 6ms/step - loss: 1077032.0000 - mae: 651.7002 - mse: 1077032.0000 - val_loss: 970861.5625 - val_mae: 623.3839 - val_mse: 970861.5625
Epoch 5/100
10/10 [==============================] - 0s 6ms/step - loss: 1143421.6080 - mae: 663.8522 - mse: 1143421.6080 - val_loss: 970777.5625 - val_mae: 623.3135 - val_mse: 970777.5625
Epoch 6/100
10/10 [==============================] - 0s 6ms/step - loss: 1116523.0341 - mae: 657.3438 - mse: 1116523.0341 - val_loss: 970693.1250 - val_mae: 623.2427 - val_mse: 970693.1250
Epoch 7/100
10/10 [==============================] - 0s 6ms/step - loss: 1108121.0341 - mae: 659.6135 - mse: 1108121.0341 - val_loss: 970609.1875 - val_mae: 623.1724 - val_mse: 970609.1875
Epoch 8/100
10/10 [==============================] - 0s 7ms/step - loss: 1186299.5682 - mae: 672.1452 - mse: 1186299.5682 - val_loss: 970526.1250 - val_mae: 623.1025 - val_mse: 970526.1250
Epoch 9/100
10/10 [==============================] - 0s 6ms/step - loss: 1272202.3295 - mae: 679.9844 - mse: 1272202.3295 - val_loss: 970442.6250 - val_mae: 623.0325 - val_mse: 970442.6250
Epoch 10/100
10/10 [==============================] - 0s 6ms/step - loss: 1146808.7159 - mae: 668.4530 - mse: 1146808.7159 - val_loss: 970359.3125 - val_mae: 622.9626 - val_mse: 970359.3125
Epoch 11/100
10/10 [==============================] - 0s 6ms/step - loss: 1097632.0682 - mae: 649.6506 - mse: 1097632.0682 - val_loss: 970274.7500 - val_mae: 622.8919 - val_mse: 970274.7500
Epoch 12/100
10/10 [==============================] - 0s 5ms/step - loss: 1177763.6705 - mae: 673.2304 - mse: 1177763.6705 - val_loss: 970191.2500 - val_mae: 622.8217 - val_mse: 970191.2500
Epoch 13/100
10/10 [==============================] - 0s 6ms/step - loss: 1113369.6477 - mae: 659.2293 - mse: 1113369.6477 - val_loss: 970107.6875 - val_mae: 622.7516 - val_mse: 970107.6875
Epoch 14/100
10/10 [==============================] - 0s 6ms/step - loss: 1109816.4659 - mae: 658.1920 - mse: 1109816.4659 - val_loss: 970023.4375 - val_mae: 622.6811 - val_mse: 970023.4375
Epoch 15/100
10/10 [==============================] - 0s 6ms/step - loss: 1157198.1477 - mae: 658.9856 - mse: 1157198.1477 - val_loss: 969939.3750 - val_mae: 622.6106 - val_mse: 969939.3750
Epoch 16/100
10/10 [==============================] - 0s 6ms/step - loss: 1143274.8750 - mae: 657.9985 - mse: 1143274.8750 - val_loss: 969855.3750 - val_mae: 622.5402 - val_mse: 969855.3750
Epoch 17/100
10/10 [==============================] - 0s 6ms/step - loss: 1165513.4375 - mae: 670.8922 - mse: 1165513.4375 - val_loss: 969772.4375 - val_mae: 622.4705 - val_mse: 969772.4375
Epoch 18/100
10/10 [==============================] - 0s 7ms/step - loss: 1193134.1705 - mae: 671.8726 - mse: 1193134.1705 - val_loss: 969688.5000 - val_mae: 622.4000 - val_mse: 969688.5000
Epoch 19/100
10/10 [==============================] - 0s 8ms/step - loss: 1162164.2614 - mae: 663.2649 - mse: 1162164.2614 - val_loss: 969605.5000 - val_mae: 622.3302 - val_mse: 969605.5000
Epoch 20/100
10/10 [==============================] - 0s 6ms/step - loss: 1123151.9091 - mae: 658.6773 - mse: 1123151.9091 - val_loss: 969521.0625 - val_mae: 622.2594 - val_mse: 969521.0625
Epoch 21/100
10/10 [==============================] - 0s 6ms/step - loss: 1180219.1705 - mae: 672.3833 - mse: 1180219.1705 - val_loss: 969437.4375 - val_mae: 622.1892 - val_mse: 969437.4375
Epoch 22/100
10/10 [==============================] - 0s 6ms/step - loss: 1110395.1705 - mae: 658.6029 - mse: 1110395.1307 - val_loss: 969354.0625 - val_mae: 622.1191 - val_mse: 969354.0625
Epoch 23/100
10/10 [==============================] - 0s 6ms/step - loss: 1192435.0227 - mae: 670.6691 - mse: 1192435.0227 - val_loss: 969270.6250 - val_mae: 622.0491 - val_mse: 969270.6250
Epoch 24/100
10/10 [==============================] - 0s 6ms/step - loss: 1172744.5000 - mae: 668.4421 - mse: 1172744.5000 - val_loss: 969187.2500 - val_mae: 621.9789 - val_mse: 969187.2500
Epoch 25/100
10/10 [==============================] - 0s 6ms/step - loss: 1103317.3011 - mae: 655.8360 - mse: 1103317.3011 - val_loss: 969103.0625 - val_mae: 621.9084 - val_mse: 969103.0625
Epoch 26/100
10/10 [==============================] - 0s 6ms/step - loss: 1129796.2614 - mae: 660.2945 - mse: 1129796.2614 - val_loss: 969019.8125 - val_mae: 621.8384 - val_mse: 969019.8125
Epoch 27/100
10/10 [==============================] - 0s 6ms/step - loss: 1162134.1023 - mae: 665.4788 - mse: 1162134.1023 - val_loss: 968935.9375 - val_mae: 621.7680 - val_mse: 968935.9375
Epoch 28/100
10/10 [==============================] - 0s 6ms/step - loss: 1107790.3011 - mae: 660.0697 - mse: 1107790.1932 - val_loss: 968852.5000 - val_mae: 621.6978 - val_mse: 968852.5000
Epoch 29/100
10/10 [==============================] - 0s 6ms/step - loss: 1224679.1023 - mae: 675.6062 - mse: 1224679.1023 - val_loss: 968768.8750 - val_mae: 621.6275 - val_mse: 968768.8750
Epoch 30/100
10/10 [==============================] - 0s 6ms/step - loss: 1148625.7386 - mae: 663.3963 - mse: 1148625.7386 - val_loss: 968685.8125 - val_mae: 621.5575 - val_mse: 968685.8125
Epoch 31/100
10/10 [==============================] - 0s 22ms/step - loss: 1188917.5909 - mae: 670.0622 - mse: 1188917.5909 - val_loss: 968602.8125 - val_mae: 621.4877 - val_mse: 968602.8125
Epoch 32/100
10/10 [==============================] - 0s 6ms/step - loss: 1066631.7443 - mae: 648.4988 - mse: 1066631.7443 - val_loss: 968518.1875 - val_mae: 621.4167 - val_mse: 968518.1875
Epoch 33/100
10/10 [==============================] - 0s 7ms/step - loss: 1114294.8239 - mae: 661.9696 - mse: 1114294.8239 - val_loss: 968434.8125 - val_mae: 621.3465 - val_mse: 968434.8125
Epoch 34/100
10/10 [==============================] - 0s 6ms/step - loss: 1074721.6307 - mae: 650.8288 - mse: 1074721.6307 - val_loss: 968350.6250 - val_mae: 621.2759 - val_mse: 968350.6250
Epoch 35/100
10/10 [==============================] - 0s 6ms/step - loss: 1171792.5682 - mae: 663.4328 - mse: 1171792.5682 - val_loss: 968267.8750 - val_mae: 621.2062 - val_mse: 968267.8750
Epoch 36/100
10/10 [==============================] - 0s 6ms/step - loss: 1163355.4773 - mae: 669.8160 - mse: 1163355.4659 - val_loss: 968184.5625 - val_mae: 621.1361 - val_mse: 968184.5625
Epoch 37/100
10/10 [==============================] - 0s 6ms/step - loss: 1238443.6477 - mae: 680.7899 - mse: 1238443.6477 - val_loss: 968101.5000 - val_mae: 621.0661 - val_mse: 968101.5000
Epoch 38/100
10/10 [==============================] - 0s 6ms/step - loss: 1169701.4545 - mae: 667.5217 - mse: 1169701.4545 - val_loss: 968018.0000 - val_mae: 620.9960 - val_mse: 968018.0000
Epoch 39/100
10/10 [==============================] - 0s 6ms/step - loss: 1271506.4205 - mae: 687.1342 - mse: 1271506.4205 - val_loss: 967934.8750 - val_mae: 620.9259 - val_mse: 967934.8750
Epoch 40/100
10/10 [==============================] - 0s 6ms/step - loss: 1121816.8295 - mae: 660.9865 - mse: 1121816.8295 - val_loss: 967850.7500 - val_mae: 620.8553 - val_mse: 967850.7500
Epoch 41/100
10/10 [==============================] - 0s 6ms/step - loss: 1096724.3977 - mae: 658.3204 - mse: 1096724.3977 - val_loss: 967766.9375 - val_mae: 620.7850 - val_mse: 967766.9375
Epoch 42/100
10/10 [==============================] - 0s 6ms/step - loss: 1140753.9205 - mae: 659.0186 - mse: 1140753.9205 - val_loss: 967683.1875 - val_mae: 620.7144 - val_mse: 967683.1875
Epoch 43/100
10/10 [==============================] - 0s 6ms/step - loss: 1225529.2273 - mae: 678.8100 - mse: 1225529.2159 - val_loss: 967600.2500 - val_mae: 620.6446 - val_mse: 967600.2500
Epoch 44/100
10/10 [===========================
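
A note on reading these numbers: Keras reports the training loss as a running average over the batches of an epoch, computed while the weights are still being updated, whereas val_loss is computed once at the end of the epoch with the final weights, which by itself can make val_loss look smaller. A minimal sketch of a callback that re-evaluates the training portion at epoch end for a like-for-like comparison (X_tr, Y_tr, X_val, Y_val stand for a hypothetical manual 60/40 split; they are not variables from the script above):

class TrainEvalCallback(keras.callbacks.Callback):
    ## re-evaluate the training data at epoch end, the same way val_loss is computed
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def on_epoch_end(self, epoch, logs=None):
        loss, mae, mse = self.model.evaluate(self.x, self.y, verbose=0)
        print(f' - end-of-epoch train mse: {mse:.2f}')

history = model.fit(X_tr, Y_tr, validation_data=(X_val, Y_val),
                    epochs=100, batch_size=500,
                    callbacks=[TrainEvalCallback(X_tr, Y_tr)])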


1 Answer

Waiting for answers
