Tensorflow實現(xiàn)線性回歸模型的示例代碼_第1頁
Tensorflow實現(xiàn)線性回歸模型的示例代碼_第2頁
Tensorflow實現(xiàn)線性回歸模型的示例代碼_第3頁
Tensorflow實現(xiàn)線性回歸模型的示例代碼_第4頁
Tensorflow實現(xiàn)線性回歸模型的示例代碼_第5頁
全文預覽已結(jié)束

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)

文檔簡介

第Tensorflow實現(xiàn)線性回歸模型的示例代碼目錄1.線性與非線性回歸案例講解1.數(shù)據(jù)集2.讀取訓練數(shù)據(jù)Income.csv并可視化展示3.利用Tensorflow搭建和訓練神經(jīng)網(wǎng)絡(luò)模型【線性回歸模型的建立】4.模型預測

1.線性與非線性回歸

線性回歸LinearRegression:兩個變量之間的關(guān)系是一次函數(shù)關(guān)系的圖像是直線,叫做線性。線性是指廣義的線性,也就是數(shù)據(jù)與數(shù)據(jù)之間的關(guān)系,如圖x1。

非線性回歸:兩個變量之間的關(guān)系不是一次函數(shù)關(guān)系的圖像不是直線,叫做非線性,如圖x2。

一元線性回歸:只包括一個自變量和一個因變量,且二者的關(guān)系可用一條直線近似表示,這種回歸分析稱為一元線性回歸分析。函數(shù)表達:y=bx+a。

多元線性回歸:包括兩個或兩個以上相互獨立的自變量(x1,x2,x3...),且因變量(y)和自變量之間是線性關(guān)系,則稱為多元線性回歸分析。函數(shù)表達:

線性回歸在深度學習中的應(yīng)用:在深度學習中,我們就是要根據(jù)已知數(shù)據(jù)點(自變量)和因變量(y)去訓練模型得到未知參數(shù)a和b、和的具體值,從而得到預測模型,在這里()相當于深度學習中目標對象的特征,(y)相當于具體的目標對象。得到預測模型之后再對未知的自變量x進行預測,得到預測的y。

線性回歸問題與分類問題:與回歸相對的是分類問題(classification),分類問題預測輸出的y值是有限的,預測值y只能是有限集合內(nèi)的一個。而當要預測值y輸出集合是無限且連續(xù),我們稱之為回歸。比如,天氣預報預測明天是否下雨,是一個二分類問題;預測明天的降雨量多少,就是一個回歸問題。

案例講解

了解基礎(chǔ)概念之后,使用Tensorflow實現(xiàn)一個簡單的一元線性回歸問題,調(diào)查學歷和收入之間的線性關(guān)系,如下所示:

求解未知參數(shù)a和b的方法:

1.數(shù)據(jù)集

模型訓練的數(shù)據(jù)存儲在一個.csv文件里,Education代表學歷【自變量x】,Income代表收入【因變量y】。

目標:我們要利用已知的Education和income數(shù)據(jù)值,求解未知參數(shù)a和b的值,得到Education和Income之間的線性關(guān)系。

2.讀取訓練數(shù)據(jù)Income.csv并可視化展示

importtensorflowastf

importnumpyasnp

#1.查看tensorflow版本

print("TensorflowVersion{}".format(tf.__version__))

#2.pandas讀取包含線性關(guān)系的.csv文件

importpandasaspd

data=pd.read_csv('D:\Project\TesorFlow\datasets\Income.csv')

print(data)

#3.繪制線性回歸關(guān)系-散點圖

importmatplotlib.pyplotasplt

plt.scatter(data.Education,data.Income)

plt.show()

3.利用Tensorflow搭建和訓練神經(jīng)網(wǎng)絡(luò)模型【線性回歸模型的建立】

#4.順序模型squential的建立

#順序模型是指網(wǎng)絡(luò)是一層一層搭建的,前面一層的輸出是后一層的輸入。

model=tf.keras.Sequential()

model.add(tf.keras.layers.Dense(1,input_shape=(1,)))

#dense(輸出數(shù)據(jù)的維度,輸入數(shù)據(jù)的維度)

#5.查看模型的結(jié)構(gòu)

model.summary()

#6.編譯模型-配置的過程,優(yōu)化算法方式(梯度下降)、損失函數(shù)

#Adam優(yōu)化器的學習速率默認為0.01

pile(optimizer='adam',

loss='mse')

#7.訓練模型,記錄模型的訓練過程history

#訓練過程是loss函數(shù)值降低的過程:

#即不斷逼近最優(yōu)的a和b參數(shù)值的過程

#這個過程要訓練很多次epoch,epoch是指對所有訓練數(shù)據(jù)訓練的次數(shù)

history=model.fit(x,y,epochs=100)

model.summary():查看我們創(chuàng)建的神經(jīng)網(wǎng)絡(luò)模型,這里我們只添加了一層全連接層。

訓練過程:這里只訓練100個epoch.

4.模型預測

#8.已知數(shù)據(jù)預測

model.predict(x)

print(model.predict(x))

#9.隨機數(shù)據(jù)預測:

#"""

#注意:pandas數(shù)據(jù)結(jié)構(gòu)是數(shù)據(jù)框DataFrame和序列Series

#序列(Series)是二維表格中的一列或者一行。實際上,當訪問DataFrame的一行時,pandas自動把該行轉(zhuǎn)換為序列;當訪問DataFrame的一列時,Pandas也自動把該列轉(zhuǎn)換為序列。

#序列是由一組數(shù)據(jù)(各種NumPy數(shù)據(jù)類型),以及一組與之相關(guān)的數(shù)據(jù)標簽(索引)組成,序列不要求數(shù)據(jù)類型是相同的,序列可以看作是一維數(shù)組(一行或一列)

#序列的表現(xiàn)形式為:索引在左邊,值在右邊。由于沒有顯式為Series指定索引,pandas會自動創(chuàng)建一個從0到N-1的整數(shù)型索引。

#"""

#test_predict=model.predict(pd.Series([20]))#所以這里輸入時需要將其轉(zhuǎn)換為Series結(jié)構(gòu)

test_predict=model.predict(pd.Series([10,20]))#預測的數(shù)據(jù)為10和20

print(test_predict)

print(pd.DataFrame([(10,20,30)]))

已知結(jié)果的數(shù)據(jù)預測的結(jié)果:查看我們創(chuàng)建的神經(jīng)網(wǎng)絡(luò)模型,這里我們只添加了一層全連接層。

未知結(jié)果的數(shù)據(jù)預測的結(jié)果:可以看到預測結(jié)果很差,說明我們的神經(jīng)網(wǎng)絡(luò)模型并沒有訓練好,求解得到的未知參數(shù)的a和b的

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預覽,若沒有圖紙預覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責。
  • 6. 下載文件中如有侵權(quán)或不適當內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論