import torch import matplotlib.pyplot as plt import numpy as np plt.rcParams['font.family'] = 'SimHei' # --- 1. 准备非线性数据 --- # 我们来创建一个更复杂的数据模式,比如一条带有线性趋势的正弦曲线 # 真实函数: y = sin(x) + 0.5*x + noise TRUE_FUNCTION = lambda x: 2*x + 3 # TRUE_FUNCTION = lambda x: torch.sin(x) + 0.5 * x # 创建 100 个数据点 X = torch.linspace(0, 16 * np.pi, 120).unsqueeze(1) y = TRUE_FUNCTION(X) + torch.randn(X.size()) * 0.4 # 加入一些噪声 # 可视化我们创建的数据集 plt.figure(figsize=(10, 5)) plt.scatter(X.numpy(), y.numpy(), c='blue', label='原始数据点') plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线') plt.title("非线性模拟数据") plt.xlabel("X") plt.ylabel("y") plt.legend() plt.grid(True) plt.show() # --- 2. 定义多项式模型和参数 --- # 我们的模型是 y_pred = a*x^3 + b*x^2 + c*x + d # 我们需要学习 4 个参数: a, b, c, d # 随机初始化它们,并设置 requires_grad=True a = torch.randn(1, requires_grad=True) b = torch.randn(1, requires_grad=True) c = torch.randn(1, requires_grad=True) d = torch.randn(1, requires_grad=True) print("--- 训练开始前 ---") print(f"随机初始化的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}") # --- 3. 定义损失函数和优化器 --- # 对于更复杂的模型,Adam 优化器通常表现更好 # 学习率可以适当调高一点 learning_rate = 0.01 optimizer = torch.optim.Adam([a, b, c, d], lr=learning_rate) # 损失函数仍然使用均方误差 loss_fn = torch.nn.MSELoss() # --- 4. 训练循环 --- # 增加训练周期数,因为模型更复杂,需要更多时间学习 epochs = 18000 all_losses = [] plt.figure(figsize=(12, 6)) plt.ion() for epoch in range(epochs): # 4.1 前向传播 (Forward Pass) - 这是唯一需要大改的地方! # 根据当前的 a, b, c, d 计算预测值 y_pred = a * X ** 3 + b * X ** 2 + c * X + d # 4.2 计算损失 loss = loss_fn(y_pred, y) all_losses.append(loss.item()) # 4.3 清空过往梯度 optimizer.zero_grad() # 4.4 反向传播 - 核心步骤,但代码完全不变! # PyTorch 自动处理复杂的求导链条 loss.backward() # 4.5 更新参数 - 代码也完全不变! optimizer.step() # --- 可视化 --- if (epoch + 1) % 100 == 0: print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}') plt.clf() plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据') plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线') # detach() 是为了切断梯度追踪,因为绘图不需要计算梯度 plt.plot(X.numpy(), y_pred.detach().numpy(), 'r-', label='当前拟合曲线') plt.title(f"训练过程 - Epoch {epoch + 1}") plt.xlabel("X") plt.ylabel("y") plt.ylim(y.min() - 1, y.max() + 1) # 固定Y轴范围,防止图像跳动 plt.legend() plt.grid(True) plt.pause(0.1) plt.ioff() plt.show() # --- 5. 最终结果可视化 --- print("\n--- 训练完成 ---") print(f"学习到的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}") plt.figure(figsize=(10, 5)) plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据') plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线') final_pred = a * X ** 3 + b * X ** 2 + c * X + d plt.plot(X.numpy(), final_pred.detach().numpy(), 'r-', label='最终拟合曲线') plt.title("最终拟合结果") plt.xlabel("X") plt.ylabel("y") plt.legend() plt.grid(True) plt.show() # 绘制损失函数下降曲线 plt.figure(figsize=(10, 5)) plt.plot(range(epochs), all_losses) plt.title("损失函数下降曲线") plt.xlabel("周期 (Epoch)") plt.ylabel("损失 (Loss)") plt.yscale('log') # 使用对数坐标轴,可以更清晰地看到早期的快速下降和后期的缓慢优化 plt.grid(True) plt.show()