Răsfoiți Sursa

使用gradio

AnlaAnla 1 an în urmă
părinte
comite
4e8ffe4b54
12 a modificat fișierele cu 278 adăugiri și 258 ștergeri
  1. 74 4
      .idea/deployment.xml
  2. 4 0
      .idea/train_model_server.iml
  3. 8 0
      Config.py
  4. 0 38
      README.MD
  5. 0 92
      data_transmit.py
  6. 9 0
      requirements.txt
  7. 0 12
      show_server.py
  8. 0 17
      train_params.json
  9. 41 0
      utils/MyModel.py
  10. 59 0
      utils/MyOnnxYolo.py
  11. 59 95
      utils/Train_resnet.py
  12. 24 0
      utils/preview_img.py

+ 74 - 4
.idea/deployment.xml

@@ -1,22 +1,92 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="martin@192.168.56.116:22 password" remoteFilesAllowedToDisappearOnAutoupload="false">
+  <component name="PublishConfigData" autoUpload="Always" serverName="martin@192.168.66.117 password" remoteFilesAllowedToDisappearOnAutoupload="false">
     <serverData>
-      <paths name="martin@192.168.56.116:22 password">
+      <paths name="er@192.168.56.169:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="er@192.168.56.169:22 password (2)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@100.64.1.9:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@100.64.1.9:22 password (2)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@100.64.1.9:22 password (3)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@100.64.1.9:22 password (4)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@100.64.1.9:22 password (5)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@192.168.66.117 password">
         <serverdata>
           <mappings>
             <mapping deploy="/media/martin/DATA/_ML/RemoteProject/train_model_server" local="$PROJECT_DIR$" web="/" />
           </mappings>
         </serverdata>
       </paths>
-      <paths name="pi@192.168.56.156:22 password">
+      <paths name="martin@192.168.66.117:22 password (6)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="pi@192.168.56.117:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="pi@192.168.66.116:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="pi@192.168.66.156:22 password">
         <serverdata>
           <mappings>
             <mapping local="$PROJECT_DIR$" web="/" />
           </mappings>
         </serverdata>
       </paths>
-      <paths name="pi@192.168.56.156:22 password (2)">
+      <paths name="pi@192.168.66.156:22 password (2)">
         <serverdata>
           <mappings>
             <mapping local="$PROJECT_DIR$" web="/" />

+ 4 - 0
.idea/train_model_server.iml

@@ -5,4 +5,8 @@
     <orderEntry type="jdk" jdkName="yolov8" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
+  <component name="PackageRequirementsSettings">
+    <option name="versionSpecifier" value="Greater or equal (&gt;=x.y.z)" />
+    <option name="removeUnused" value="true" />
+  </component>
 </module>

+ 8 - 0
Config.py

@@ -0,0 +1,8 @@
+class Config:
+    data = 1
+    def __init__(self):
+        # 用于存储已创建的目录名称
+        self.data_dir = './data'
+        self.model_dir = './runs'
+
+        self.yolo_model_path = "./Model/yolo_handcard01.onnx"

+ 0 - 38
README.MD

@@ -1,38 +0,0 @@
-## train 0.0.0.0:port=6661
-###  运行train.py后, 等待传递参数进行训练
-```python
-# 发送请求方法
-import json
-import requests
-
-with open('train_params.json', 'r', encoding='utf-8') as f:
-    data = f.read()
-url = "http://192.168.56.116:6661/train/params_json"
-response = requests.post(url, files={"file": data})
-
-print(response)
-```
-
-## 查询训练过程  0.0.0.0:port=6662
-### 服务器运行命令
-```bash
-tensorboard --logdir=runs --host=0.0.0.0 --port 6662
-```
-### 客户端访问 http://0.0.0.0:6662
-
-
-## 客户端请求服务器数据
-```bash
-上传文件:
-curl -X POST -F "file=@/path/to/your/file.zip" http://localhost:6664/upload
-
-下载文件:
-curl -X GET http://localhost:6664/download/file.zip -o downloaded_file.zip
-
-删除文件:
-curl -X POST http://localhost:6664/delete/file.zip
-
-清空 uploads 目录:
-curl -X POST http://localhost:6664/clear
-
-```

+ 0 - 92
data_transmit.py

@@ -1,92 +0,0 @@
-from fastapi import FastAPI, File, UploadFile
-from fastapi.responses import FileResponse
-import os
-import shutil
-import zipfile
-
-app = FastAPI()
-
-# 存储上传的文件的目录
-DATA_FOLDER = 'data'
-MODEL_FOLDER = 'runs'
-if not os.path.exists(DATA_FOLDER):
-    os.makedirs(DATA_FOLDER)
-
-if not os.path.exists(MODEL_FOLDER):
-    os.makedirs(MODEL_FOLDER)
-
-
-@app.post("/upload")
-async def upload_file(file: UploadFile):
-    data_dir = os.path.join(DATA_FOLDER, os.path.splitext(file.filename)[0])
-
-    # 排除名称重复的数据
-    if os.path.exists(data_dir):
-        return {"error": "存在重名数据"}
-
-    # 检查文件是否为 ZIP 格式
-    if file.filename.endswith('.zip'):
-        # 保存 ZIP 文件到 uploads 目录
-        zip_file_path = os.path.join(DATA_FOLDER, file.filename)
-        with open(zip_file_path, "wb") as f:
-            f.write(await file.read())
-
-        # 解压 ZIP 文件
-        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
-
-            os.makedirs(data_dir, exist_ok=True)
-            zip_ref.extractall(data_dir)
-
-        # 删除原始 ZIP 文件
-        os.remove(zip_file_path)
-
-        return {"message": f"数据 '{file.filename}' 上传成功"}
-    else:
-        return {"error": "我只要zip文件"}
-
-
-@app.get("/download/{model_id}")
-async def download_file(model_id: str):
-    # 从 runs 目录下载指定id的模型文件
-    folder_path = os.path.join(DATA_FOLDER, model_id)
-    if not os.path.exists(folder_path):
-        return {f"{folder_path} 不存在"}
-
-    model_name = [name for name in os.listdir(folder_path) if '.pth' in name][0]
-    model_path = os.path.join(folder_path, model_name)
-
-    return FileResponse(model_path, media_type='application/octet-stream', filename=model_name)
-
-
-# @app.post("/delete/{filename}")
-# async def delete_file(filename: str):
-#     # 删除 uploads 目录下的文件
-#     file_path = os.path.join(DATA_FOLDER, filename)
-#     os.remove(file_path)
-#     return {"message": f"文件 '{filename}' 成功删除"}
-
-@app.post("/check/{dir_name}")
-async def clear_uploads(dir_name: str):
-    # 查询 dir_name 目录
-    if dir_name == 'data':
-        data = os.listdir(DATA_FOLDER)
-    else:
-        data = os.listdir(MODEL_FOLDER)
-    return data
-
-
-@app.post("/clear/{dir_name}")
-async def clear_uploads(dir_name: str):
-    # 清空 dir_name 目录
-    if dir_name == 'data':
-        shutil.rmtree(DATA_FOLDER)
-        os.makedirs(DATA_FOLDER)
-    elif dir_name == 'models':
-        shutil.rmtree(MODEL_FOLDER)
-    return {dir_name, " 目录已经清空"}
-
-
-if __name__ == "__main__":
-    import uvicorn
-
-    uvicorn.run(app, host="0.0.0.0", port=6664)

+ 9 - 0
requirements.txt

@@ -0,0 +1,9 @@
+torch>=2.3.0
+numpy>=1.22.3
+torchvision>=0.18.0
+gradio>=4.32.2
+psutil>=5.9.8
+uvicorn>=0.28.0
+fastapi>=0.110.0
+starlette>=0.36.3
+resnest --pre

+ 0 - 12
show_server.py

@@ -1,12 +0,0 @@
-import os
-
-
-if __name__ == '__main__':
-    folder_path = './newest_log'
-    if not os.path.exists(folder_path):
-        os.mkdir(folder_path)
-
-    if len(os.listdir(folder_path)) == 0:
-        print('没有 log')
-    else:
-        print(os.system(f'tensorboard --logdir={folder_path} --host=0.0.0.0 --port 6662'))

+ 0 - 17
train_params.json

@@ -1,17 +0,0 @@
-{
-  "model_name": "resnet50",
-  "data_dir_path": "data/data_test",
-
-  "load_my_model": false,
-  "model_path": "",
-
-  "is_transfer_learn": false,
-  "transfer_cls": 12796,
-
-  "params":{
-    "epoch": 30,
-    "learn_rate": 0.01,
-    "momentum": 0.9,
-    "decay": 0.7
-  }
-}

+ 41 - 0
utils/MyModel.py

@@ -0,0 +1,41 @@
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+class MyModel:
+    def __init__(self, model_path: str) -> None:
+
+        self.norm_mean = [0.485, 0.456, 0.406]
+        self.norm_std = [0.229, 0.224, 0.225]
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.model = torch.load(model_path)
+        self.model.eval()
+
+    def inference_transform(self):
+        inference_transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(self.norm_mean, self.norm_std),
+        ])
+        return inference_transform
+
+    # 输入图片, 获取图片特征向量
+    def run(self, img):
+        if type(img) == type('path'):
+            img = Image.open(img).convert('RGB')
+        else:
+            img = Image.fromarray(img)
+            img = img.convert('RGB')
+        transform = self.inference_transform()
+
+        img_tensor = transform(img)
+        img_tensor = img_tensor.unsqueeze(0).to(self.device)
+
+        # Perform prediction
+        with torch.no_grad():
+            outputs = self.model(img_tensor)
+            _, predicted = torch.max(outputs.data, 1)
+        return int(predicted)

+ 59 - 0
utils/MyOnnxYolo.py

@@ -0,0 +1,59 @@
+from ultralytics import YOLO
+
+'''用法
+model = MyOnnxYolo(r"yolo_handcard01.onnx")
+img = Image.open(img_path).convert('RGB')
+model.set_result(img)
+yolo_img = model.get_max_img(cls_id=0)
+'''
+class MyOnnxYolo:
+    # cls_id {card:0, person:1, hand:2}
+    def __init__(self, model_path):
+
+        # 加载yolo model
+
+        self.model = YOLO(model_path, verbose=False, task='detect')
+        self.results = None
+
+        self.cls = None
+        self.boxes = None
+        self.img = None
+
+    def set_result(self, img, imgsz=640):
+        self.results = self.model.predict(img, max_det=3, verbose=False, imgsz=imgsz)
+
+        self.img = self.results[0].orig_img
+        self.boxes = self.results[0].boxes.xyxy.cpu()
+        self.cls = self.results[0].boxes.cls
+
+    def check(self, cls_id: int):
+        if cls_id in self.cls:
+            return True
+        return False
+
+    def get_max_img(self, cls_id: int):
+        # cls_id {card:0, person:1, hand:2}
+        # 排除没有检测到物体 或 截取的id不存在的图片
+        if len(self.boxes) == 0 or cls_id not in self.cls:
+            return self.img
+
+        max_area = 0
+        # 选出最大的卡片框
+        x1, y1, x2, y2 = 0, 0, 0, 0
+        for i, box in enumerate(self.boxes):
+            if self.cls[i] != cls_id:
+                continue
+
+            temp_x1, temp_y1, temp_x2, temp_y2 = box
+            area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+            if area > max_area:
+                max_area = area
+                x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
+
+        x1 = int(x1)
+        x2 = int(x2)
+        y1 = int(y1)
+        y2 = int(y2)
+        max_img = self.img[y1:y2, x1:x2, :]
+
+        return max_img

+ 59 - 95
train.py → utils/Train_resnet.py

@@ -8,20 +8,27 @@ import torch.optim as optim
 from torch.optim import lr_scheduler
 import torch.backends.cudnn as cudnn
 import numpy as np
-import torchvision
 from torchvision import datasets, models, transforms
-from fastapi import FastAPI, File, UploadFile
-import uvicorn
-import json
 import time
 import os
+import zipfile
 import copy
 import platform
 from torch.utils.tensorboard import SummaryWriter
 
 cudnn.benchmark = True
 
-app = FastAPI()
+data_phase = ['train', 'val']
+database_dir = './data'
+output_dir = './runs'  # 模型保存和日志备份大目录
+newest_log = './newest_log'  # 最新日志保存目录
+log_port = 6667  # tensorboard日志端口
+writer: SummaryWriter
+
+# 用异步接收请求判断
+task_list = ['train', 'data_process']
+running_task = set()
+exit_flag = False
 
 
 def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
@@ -39,14 +46,10 @@ def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, lengt
 
 
 # 备份log
-def move_log(model_id):
-    if not os.path.exists(output_dir):
-        print('缺失路径: ', output_dir)
-        return
-
+def move_log(model_save_dir):
     log_name = os.listdir(newest_log)[0]
     log_path = os.path.join(newest_log, log_name)
-    save_path = os.path.join(output_dir, log_name)
+    save_path = os.path.join(model_save_dir, log_name)
 
     shutil.copy(log_path, save_path)
     print('log 已备份')
@@ -105,7 +108,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
             epoch_acc = running_corrects.double() / dataset_sizes[phase]
 
             writer.add_scalar(phase + " loss", epoch_loss, epoch + 1)
-            writer.add_scalar(phase + " accuracy", epoch_loss, epoch + 1)
+            writer.add_scalar(phase + " accuracy", epoch_acc, epoch + 1)
 
             print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
 
@@ -127,65 +130,58 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     return model
 
 
-def train(epoch=30, save_path='resnet.pth', load_my_model=False, model_path=None,
-          is_freeze=True, freeze_num=7, is_transfer_learn=False, transfer_cls=None):
+def train(epoch=30, save_path='resnet.pt', model_path=None,
+          freeze=7, learn_rate=0.01, momentum=0.9, decay=0.7):
     # 如果不加载训练过的模型则加载预训练模型
-    if load_my_model:
-        model = models.resnet50(pretrained=False)
-        num_features = model.fc.in_features
-        if is_transfer_learn:
-            # 加载旧模型后,更改为新模型的分类格式
-            model.fc = nn.Linear(num_features, transfer_cls)
-            model.load_state_dict(torch.load(model_path))
-
-            # 修改最后一层
-            num_features = model.fc.in_features
-            model.fc = nn.Linear(num_features, class_num)
-        else:
-            # 修改最后一层
-            model.fc = nn.Linear(num_features, class_num)
-            model.load_state_dict(torch.load(model_path))
-
-    else:
+    if model_path is None or model_path == '':
         model = models.resnet50(pretrained=True)
         # 修改最后一层
         num_features = model.fc.in_features
         model.fc = nn.Linear(num_features, class_num)
+    else:
+        model = torch.load(model_path)
+        old_cls_num = model.fc.out_features
+        if class_num == old_cls_num:
+            print('分类头适合, 进行训练')
+        else:
+            # 修改最后一层
+            num_features = model.fc.in_features
+            model.fc = nn.Linear(num_features, class_num)
+            print(f"修改分类头: {old_cls_num} --> {class_num}")
 
     model = model.to(device)
 
     criterion = nn.CrossEntropyLoss()
-    optimizer_ft = optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)
+    optimizer_ft = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
 
     # 冻结部分参数
-    if is_freeze:
-        for i, c in enumerate(model.children()):
-            if i == freeze_num:
-                break
-            for param in c.parameters():
-                param.requires_grad = False
+    for i, c in enumerate(model.children()):
+        if i == freeze:
+            break
+        for param in c.parameters():
+            param.requires_grad = False
 
-        for param in model.parameters():
-            print(param.requires_grad)
+    for param in model.parameters():
+        print(param.requires_grad)
 
     model = train_model(model, criterion, optimizer_ft,
                         exp_lr_scheduler, num_epochs=epoch)
 
-    torch.save(model.state_dict(), save_path)
+    torch.save(model, save_path)
+
 
+'''
+    epoch: 训练次数
+    save_path: 模型保存路径
 
-@app.post("/train/params_json")
-async def upload_json(file: UploadFile = File(...)):
-    contents = await file.read()
-    json_data = contents.decode("utf-8")
-    # 处理JSON数据
-    print(json_data)
-    json_data = json.loads(json_data)
+    model_path: 加载的模型路径 
+    freeze_num: 冻结层数
 
-    global data_dir, device, class_num, dataloaders, dataset_sizes, output_dir, writer
+    '''
+def load_param(epoch, data_dir, model_path, freeze, learn_rate, momentum, decay):
+    global device, class_num, dataloaders, dataset_sizes, output_dir, writer
 
-    data_dir = json_data['data_dir_path']
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
     data_transforms = {
@@ -216,30 +212,12 @@ async def upload_json(file: UploadFile = File(...)):
     class_num = len(class_names)
     print('class:', class_num)
 
-    '''
-    epoch: 训练次数
-    save_path: 模型保存路径
-    load_my_model: 是否加载训练过的模型
-    model_path: 加载的模型路径 
-    is_freeze: 是否冻结模型部分层数
-    freeze_num: 冻结层数
-    is_transfer_learn: 是否迁移学习
-    transfer_cls: 迁移学习旧模型分类头数量
-    '''
-    # 数据集路径在本文件上面
-    epoch = json_data['params']['epoch']
-    load_my_model = json_data['load_my_model']
-    model_path = json_data['model_path']
-    is_transfer_learn = json_data['is_transfer_learn']
-    transfer_cls = json_data['transfer_cls']
-
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
+    print(f"输入参数: 训练次数: {epoch}, 模型路径: {model_path}")
 
     model_id = len(os.listdir(output_dir)) + 1
-    output_dir = os.path.join(output_dir, str(model_id))
-    if not os.path.exists(output_dir):
-        os.mkdir(output_dir)
+    model_save_dir = os.path.join(output_dir, str(model_id))
+    if not os.path.exists(model_save_dir):
+        os.mkdir(model_save_dir)
 
     # 删除旧的log
     if len(os.listdir(newest_log)) > 0:
@@ -247,30 +225,16 @@ async def upload_json(file: UploadFile = File(...)):
     writer = SummaryWriter(newest_log)
     writer.add_text('model', "model id: " + str(model_id))
 
-    save_path = os.path.join(output_dir, str(json_data['model_name']) + '_out' + str(class_num) + '.pth')
+    save_path = os.path.join(model_save_dir, 'resnet50_out' + str(class_num) + '.pt')
 
-    train(epoch=epoch, save_path=save_path,
-          load_my_model=load_my_model,
+    train(epoch=epoch,
+          save_path=save_path,
           model_path=model_path,
-          is_transfer_learn=is_transfer_learn, transfer_cls=transfer_cls
-          )
+          freeze=freeze,
+          learn_rate=learn_rate,
+          momentum=momentum,
+          decay=decay)
+
     writer.flush()
     writer.close()
-    move_log(model_id)
-
-    return {"train end"}
-
-
-if __name__ == '__main__':
-    data_dir: str
-    device = None
-    class_num: int
-    dataloaders = None
-    dataset_sizes = None
-
-    data_phase = ['train', 'val']
-    output_dir = './runs'
-    newest_log = './newest_log'
-    writer: SummaryWriter
-
-    uvicorn.run(app, host='0.0.0.0', port=6661)
+    move_log(model_save_dir)

+ 24 - 0
utils/preview_img.py

@@ -0,0 +1,24 @@
+import os
+from Config import Config
+
+config = Config()
+
+def get_img_lits(img_dir):
+    imgs_List = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir)[:10]) if
+                 name.endswith(('.png', '.jpg', '.webp', '.tif', '.jpeg', '.JPG', '.PNG'))]
+    return imgs_List
+
+
+def preview_img_dir(train_dir_name):
+    print('预览: ', train_dir_name)
+
+    train_dir = os.path.join(config.data_dir, train_dir_name, 'train')
+    img_dir = os.path.join(train_dir, os.listdir(train_dir)[0])
+
+    img_paths_list = get_img_lits(img_dir)  # 注意传入自定义的web
+    # 结果为 list,里面对象可以为
+    dict_path = []
+    for i in range(len(img_paths_list)):
+        dict_path.append((img_paths_list[i], 'img_descrip' + str(i)))  # 图片路径,图片描述, 图片描述可以自定义字符串
+    print(dict_path)
+    return dict_path