test01.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import torch
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. plt.rcParams['font.family'] = 'SimHei'
  5. # --- 1. 准备非线性数据 ---
  6. # 我们来创建一个更复杂的数据模式,比如一条带有线性趋势的正弦曲线
  7. # 真实函数: y = sin(x) + 0.5*x + noise
  8. TRUE_FUNCTION = lambda x: 2*x + 3
  9. # TRUE_FUNCTION = lambda x: torch.sin(x) + 0.5 * x
  10. # 创建 100 个数据点
  11. X = torch.linspace(0, 16 * np.pi, 120).unsqueeze(1)
  12. y = TRUE_FUNCTION(X) + torch.randn(X.size()) * 0.4 # 加入一些噪声
  13. # 可视化我们创建的数据集
  14. plt.figure(figsize=(10, 5))
  15. plt.scatter(X.numpy(), y.numpy(), c='blue', label='原始数据点')
  16. plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
  17. plt.title("非线性模拟数据")
  18. plt.xlabel("X")
  19. plt.ylabel("y")
  20. plt.legend()
  21. plt.grid(True)
  22. plt.show()
  23. # --- 2. 定义多项式模型和参数 ---
  24. # 我们的模型是 y_pred = a*x^3 + b*x^2 + c*x + d
  25. # 我们需要学习 4 个参数: a, b, c, d
  26. # 随机初始化它们,并设置 requires_grad=True
  27. a = torch.randn(1, requires_grad=True)
  28. b = torch.randn(1, requires_grad=True)
  29. c = torch.randn(1, requires_grad=True)
  30. d = torch.randn(1, requires_grad=True)
  31. print("--- 训练开始前 ---")
  32. print(f"随机初始化的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}")
  33. # --- 3. 定义损失函数和优化器 ---
  34. # 对于更复杂的模型,Adam 优化器通常表现更好
  35. # 学习率可以适当调高一点
  36. learning_rate = 0.01
  37. optimizer = torch.optim.Adam([a, b, c, d], lr=learning_rate)
  38. # 损失函数仍然使用均方误差
  39. loss_fn = torch.nn.MSELoss()
  40. # --- 4. 训练循环 ---
  41. # 增加训练周期数,因为模型更复杂,需要更多时间学习
  42. epochs = 18000
  43. all_losses = []
  44. plt.figure(figsize=(12, 6))
  45. plt.ion()
  46. for epoch in range(epochs):
  47. # 4.1 前向传播 (Forward Pass) - 这是唯一需要大改的地方!
  48. # 根据当前的 a, b, c, d 计算预测值
  49. y_pred = a * X ** 3 + b * X ** 2 + c * X + d
  50. # 4.2 计算损失
  51. loss = loss_fn(y_pred, y)
  52. all_losses.append(loss.item())
  53. # 4.3 清空过往梯度
  54. optimizer.zero_grad()
  55. # 4.4 反向传播 - 核心步骤,但代码完全不变!
  56. # PyTorch 自动处理复杂的求导链条
  57. loss.backward()
  58. # 4.5 更新参数 - 代码也完全不变!
  59. optimizer.step()
  60. # --- 可视化 ---
  61. if (epoch + 1) % 100 == 0:
  62. print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
  63. plt.clf()
  64. plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据')
  65. plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
  66. # detach() 是为了切断梯度追踪,因为绘图不需要计算梯度
  67. plt.plot(X.numpy(), y_pred.detach().numpy(), 'r-', label='当前拟合曲线')
  68. plt.title(f"训练过程 - Epoch {epoch + 1}")
  69. plt.xlabel("X")
  70. plt.ylabel("y")
  71. plt.ylim(y.min() - 1, y.max() + 1) # 固定Y轴范围,防止图像跳动
  72. plt.legend()
  73. plt.grid(True)
  74. plt.pause(0.1)
  75. plt.ioff()
  76. plt.show()
  77. # --- 5. 最终结果可视化 ---
  78. print("\n--- 训练完成 ---")
  79. print(f"学习到的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}")
  80. plt.figure(figsize=(10, 5))
  81. plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据')
  82. plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
  83. final_pred = a * X ** 3 + b * X ** 2 + c * X + d
  84. plt.plot(X.numpy(), final_pred.detach().numpy(), 'r-', label='最终拟合曲线')
  85. plt.title("最终拟合结果")
  86. plt.xlabel("X")
  87. plt.ylabel("y")
  88. plt.legend()
  89. plt.grid(True)
  90. plt.show()
  91. # 绘制损失函数下降曲线
  92. plt.figure(figsize=(10, 5))
  93. plt.plot(range(epochs), all_losses)
  94. plt.title("损失函数下降曲线")
  95. plt.xlabel("周期 (Epoch)")
  96. plt.ylabel("损失 (Loss)")
  97. plt.yscale('log') # 使用对数坐标轴,可以更清晰地看到早期的快速下降和后期的缓慢优化
  98. plt.grid(True)
  99. plt.show()