跳转至

🟠 梯度下降法求极值\#

一开始只是一个很简单的应用,但是想起了我所学习过的深度学习知识,我又忍不住将梯度下降这个算法给补充完整

采用了最基础的梯度下降法求函数极值,一开始我的想法是引入外部库求导得到极值,但是这个函数参数有点多,弄得很混乱,浅浅尝试了一下发现报错就没继续了,于是决定采用定步长的梯度下降算法

在这个例子中,定步长的梯度下降算法虽然简单,但是好用,且对精度要求也不高,计算量也不大,最关键的是我们可以估计到极值点的大概位置,这可以帮助我们判断算法的正确性

想要进一步查看梯度下降算法,查看: 1. Gradient descent 2. 梯度下降法的数学原理

'''
需要用到的依赖,通过下面的命令安装:
命令行中一行一行粘贴下面的代码
pip install numpy
pip install matplotlib
--------------------------------
弹簧的弹性特性描述:
mu = 0.3
R = 125
r = 100
R_1 = 120
r_1 = 105
H = 4
h = 2.5
E = 2.06e5 # 单位MPa
'''

import numpy as np
from matplotlib import pyplot as plt

plt.rc('font', family = 'Songti SC')
# 常数
mu = 0.3
R = 125
r = 100
R_1 = 120
r_1 = 105
H = 4
h = 2.5
E = 2.06e5 # 单位MPa

rambda = np.linspace(0, 6, 10000)

def ln(x):
    return np.log(x)/np.log(np.e)

F = np.pi*E*h*rambda/(6*(1-mu**2)) * ln(R/r)/((R_1-r_1)**2) * ((H-rambda*(R-r)/(R_1-r_1))*(H-rambda/2*(R-r)/(R_1-r_1))+h**2)

# ---------------求极值点-----------------

def critical_point(x, f, critical):
    if critical == 'max':
        max = f[0]
        for i in range(1, len(x)-1):
            if (f[i] > f[i-1]) & (f[i] > f[i+1]):
                max = f[i]
                x_label = 6*i/10000 
                print(f"极小值为: ({x_label}, {np.round(max, 2)})" ) 
                return (x_label, max)

    if critical == 'min':
       for i in range(1, len(x)-1):
            if (f[i] < f[i-1]) & (f[i] < f[i+1]):
                min = f[i]
                x_label = 6*i/10000 
                print(f"极大值为: ({x_label}, {np.round(min, 2)})" ) 
                return (x_label, min)

max_point = critical_point(rambda, F, 'max')
min_point = critical_point(rambda, F, 'min')

# ---------------绘图---------------------
plt.plot(rambda, F, color = '#61bbf2')
# 极值点
plt.scatter(max_point[0], max_point[1], color='#b861f2')
plt.scatter(min_point[0], min_point[1], color='#b861f2')

plt.annotate(f"({max_point[0]}, {np.round(max_point[1], 2)})", xy=(1.2, 5000), xytext=(1.2, 5000), ha='left', va='top', fontsize=14)
plt.annotate(f"({min_point[0]}, {np.round(min_point[1], 2)})", xy=(2.5, 4000), xytext=(2.5, 4000), ha='left', va='top', fontsize=14)


plt.title("膜片弹簧工作点的位置")
plt.xlabel(r"变形  $\mathrm{x/min}$")
plt.ylabel(r"工作压力 $\mathrm{F/N}$")
plt.xlim(0,4.2)
plt.ylim(0, 6000)
plt.show() 

运行上面的代码得到:

膜片弹簧工作点的位置