| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- import torch
- import matplotlib.pyplot as plt
- import numpy as np
- plt.rcParams['font.family'] = 'SimHei'
- # --- 1. 准备非线性数据 ---
- # 我们来创建一个更复杂的数据模式,比如一条带有线性趋势的正弦曲线
- # 真实函数: y = sin(x) + 0.5*x + noise
- TRUE_FUNCTION = lambda x: 2*x + 3
- # TRUE_FUNCTION = lambda x: torch.sin(x) + 0.5 * x
- # 创建 100 个数据点
- X = torch.linspace(0, 16 * np.pi, 120).unsqueeze(1)
- y = TRUE_FUNCTION(X) + torch.randn(X.size()) * 0.4 # 加入一些噪声
- # 可视化我们创建的数据集
- plt.figure(figsize=(10, 5))
- plt.scatter(X.numpy(), y.numpy(), c='blue', label='原始数据点')
- plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
- plt.title("非线性模拟数据")
- plt.xlabel("X")
- plt.ylabel("y")
- plt.legend()
- plt.grid(True)
- plt.show()
- # --- 2. 定义多项式模型和参数 ---
- # 我们的模型是 y_pred = a*x^3 + b*x^2 + c*x + d
- # 我们需要学习 4 个参数: a, b, c, d
- # 随机初始化它们,并设置 requires_grad=True
- a = torch.randn(1, requires_grad=True)
- b = torch.randn(1, requires_grad=True)
- c = torch.randn(1, requires_grad=True)
- d = torch.randn(1, requires_grad=True)
- print("--- 训练开始前 ---")
- print(f"随机初始化的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}")
- # --- 3. 定义损失函数和优化器 ---
- # 对于更复杂的模型,Adam 优化器通常表现更好
- # 学习率可以适当调高一点
- learning_rate = 0.01
- optimizer = torch.optim.Adam([a, b, c, d], lr=learning_rate)
- # 损失函数仍然使用均方误差
- loss_fn = torch.nn.MSELoss()
- # --- 4. 训练循环 ---
- # 增加训练周期数,因为模型更复杂,需要更多时间学习
- epochs = 18000
- all_losses = []
- plt.figure(figsize=(12, 6))
- plt.ion()
- for epoch in range(epochs):
- # 4.1 前向传播 (Forward Pass) - 这是唯一需要大改的地方!
- # 根据当前的 a, b, c, d 计算预测值
- y_pred = a * X ** 3 + b * X ** 2 + c * X + d
- # 4.2 计算损失
- loss = loss_fn(y_pred, y)
- all_losses.append(loss.item())
- # 4.3 清空过往梯度
- optimizer.zero_grad()
- # 4.4 反向传播 - 核心步骤,但代码完全不变!
- # PyTorch 自动处理复杂的求导链条
- loss.backward()
- # 4.5 更新参数 - 代码也完全不变!
- optimizer.step()
- # --- 可视化 ---
- if (epoch + 1) % 100 == 0:
- print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
- plt.clf()
- plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据')
- plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
- # detach() 是为了切断梯度追踪,因为绘图不需要计算梯度
- plt.plot(X.numpy(), y_pred.detach().numpy(), 'r-', label='当前拟合曲线')
- plt.title(f"训练过程 - Epoch {epoch + 1}")
- plt.xlabel("X")
- plt.ylabel("y")
- plt.ylim(y.min() - 1, y.max() + 1) # 固定Y轴范围,防止图像跳动
- plt.legend()
- plt.grid(True)
- plt.pause(0.1)
- plt.ioff()
- plt.show()
- # --- 5. 最终结果可视化 ---
- print("\n--- 训练完成 ---")
- print(f"学习到的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}")
- plt.figure(figsize=(10, 5))
- plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据')
- plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
- final_pred = a * X ** 3 + b * X ** 2 + c * X + d
- plt.plot(X.numpy(), final_pred.detach().numpy(), 'r-', label='最终拟合曲线')
- plt.title("最终拟合结果")
- plt.xlabel("X")
- plt.ylabel("y")
- plt.legend()
- plt.grid(True)
- plt.show()
- # 绘制损失函数下降曲线
- plt.figure(figsize=(10, 5))
- plt.plot(range(epochs), all_losses)
- plt.title("损失函数下降曲线")
- plt.xlabel("周期 (Epoch)")
- plt.ylabel("损失 (Loss)")
- plt.yscale('log') # 使用对数坐标轴,可以更清晰地看到早期的快速下降和后期的缓慢优化
- plt.grid(True)
- plt.show()
|