AnlaAnla 1 年之前
當前提交
b17c45ef0a

+ 8 - 0
.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 29 - 0
.idea/deployment.xml

@@ -0,0 +1,29 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="Always" serverName="martin@192.168.56.116:22 password" remoteFilesAllowedToDisappearOnAutoupload="false">
+    <serverData>
+      <paths name="martin@192.168.56.116:22 password">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/media/martin/DATA/_ML/RemoteProject/train_model_server" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="pi@192.168.56.156:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="pi@192.168.56.156:22 password (2)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ALWAYS" />
+  </component>
+</project>

+ 35 - 0
.idea/inspectionProfiles/Project_Default.xml

@@ -0,0 +1,35 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="21">
+            <item index="0" class="java.lang.String" itemvalue="webargs" />
+            <item index="1" class="java.lang.String" itemvalue="transformers" />
+            <item index="2" class="java.lang.String" itemvalue="timm" />
+            <item index="3" class="java.lang.String" itemvalue="fluent-logger" />
+            <item index="4" class="java.lang.String" itemvalue="towhee" />
+            <item index="5" class="java.lang.String" itemvalue="flask_restful" />
+            <item index="6" class="java.lang.String" itemvalue="opencv_python" />
+            <item index="7" class="java.lang.String" itemvalue="fastapi" />
+            <item index="8" class="java.lang.String" itemvalue="seaborn" />
+            <item index="9" class="java.lang.String" itemvalue="matplotlib" />
+            <item index="10" class="java.lang.String" itemvalue="minio" />
+            <item index="11" class="java.lang.String" itemvalue="ipython" />
+            <item index="12" class="java.lang.String" itemvalue="torch" />
+            <item index="13" class="java.lang.String" itemvalue="uvicorn" />
+            <item index="14" class="java.lang.String" itemvalue="python-multipart" />
+            <item index="15" class="java.lang.String" itemvalue="torchvision" />
+            <item index="16" class="java.lang.String" itemvalue="pymilvus" />
+            <item index="17" class="java.lang.String" itemvalue="psutil" />
+            <item index="18" class="java.lang.String" itemvalue="ultralytics" />
+            <item index="19" class="java.lang.String" itemvalue="picamera2" />
+            <item index="20" class="java.lang.String" itemvalue="posix_ipc" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>

+ 6 - 0
.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 10 - 0
.idea/misc.xml

@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="yolov8" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="yolov8" project-jdk-type="Python SDK" />
+  <component name="PythonCompatibilityInspectionAdvertiser">
+    <option name="version" value="3" />
+  </component>
+</project>

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/train_model_server.iml" filepath="$PROJECT_DIR$/.idea/train_model_server.iml" />
+    </modules>
+  </component>
+</project>

+ 8 - 0
.idea/train_model_server.iml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="yolov8" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 6 - 0
.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>

+ 38 - 0
README.MD

@@ -0,0 +1,38 @@
+## train 0.0.0.0:port=6661
+###  运行train.py后, 等待传递参数进行训练
+```python
+# 发送请求方法
+import json
+import requests
+
+with open('train_params.json', 'r', encoding='utf-8') as f:
+    data = f.read()
+url = "http://192.168.56.116:6661/train/params_json"
+response = requests.post(url, files={"file": data})
+
+print(response)
+```
+
+## 查询训练过程  0.0.0.0:port=6662
+### 服务器运行命令
+```bash
+tensorboard --logdir=runs --host=0.0.0.0 --port 6662
+```
+### 客户端访问 http://<服务器IP>:6662 (0.0.0.0 只是服务器的监听地址, 客户端需使用服务器的实际 IP)
+
+
+## 客户端请求服务器数据
+```bash
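+上传文件 (仅接受 zip):
+curl -X POST -F "file=@/path/to/your/file.zip" http://localhost:6664/upload
+
+下载模型文件 (路径参数为 model_id):
+curl -X GET http://localhost:6664/download/<model_id> -o model.pth
+
+删除文件 (data_transmit.py 中该接口目前被注释, 未启用):
+curl -X POST http://localhost:6664/delete/file.zip
+
+查询目录内容 (dir_name 为 data 时查询数据目录, 其他值查询模型目录):
+curl -X POST http://localhost:6664/check/data
+
+清空目录 (dir_name 为 data 或 models):
+curl -X POST http://localhost:6664/clear/data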
+
+```

+ 92 - 0
data_transmit.py

@@ -0,0 +1,92 @@
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import FileResponse
+import os
+import shutil
+import zipfile
+
+app = FastAPI()
+
+# 存储上传的文件的目录
+DATA_FOLDER = 'data'
+MODEL_FOLDER = 'runs'
+if not os.path.exists(DATA_FOLDER):
+    os.makedirs(DATA_FOLDER)
+
+if not os.path.exists(MODEL_FOLDER):
+    os.makedirs(MODEL_FOLDER)
+
+
+@app.post("/upload")
+async def upload_file(file: UploadFile):
+    data_dir = os.path.join(DATA_FOLDER, os.path.splitext(file.filename)[0])
+
+    # 排除名称重复的数据
+    if os.path.exists(data_dir):
+        return {"error": "存在重名数据"}
+
+    # 检查文件是否为 ZIP 格式
+    if file.filename.endswith('.zip'):
+        # 保存 ZIP 文件到 uploads 目录
+        zip_file_path = os.path.join(DATA_FOLDER, file.filename)
+        with open(zip_file_path, "wb") as f:
+            f.write(await file.read())
+
+        # 解压 ZIP 文件
+        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+
+            os.makedirs(data_dir, exist_ok=True)
+            zip_ref.extractall(data_dir)
+
+        # 删除原始 ZIP 文件
+        os.remove(zip_file_path)
+
+        return {"message": f"数据 '{file.filename}' 上传成功"}
+    else:
+        return {"error": "我只要zip文件"}
+
+
+@app.get("/download/{model_id}")
+async def download_file(model_id: str):
+    # 从 runs 目录下载指定id的模型文件
+    folder_path = os.path.join(DATA_FOLDER, model_id)
+    if not os.path.exists(folder_path):
+        return {f"{folder_path} 不存在"}
+
+    model_name = [name for name in os.listdir(folder_path) if '.pth' in name][0]
+    model_path = os.path.join(folder_path, model_name)
+
+    return FileResponse(model_path, media_type='application/octet-stream', filename=model_name)
+
+
+# @app.post("/delete/{filename}")
+# async def delete_file(filename: str):
+#     # 删除 uploads 目录下的文件
+#     file_path = os.path.join(DATA_FOLDER, filename)
+#     os.remove(file_path)
+#     return {"message": f"文件 '{filename}' 成功删除"}
+
+@app.post("/check/{dir_name}")
+async def clear_uploads(dir_name: str):
+    # 查询 dir_name 目录
+    if dir_name == 'data':
+        data = os.listdir(DATA_FOLDER)
+    else:
+        data = os.listdir(MODEL_FOLDER)
+    return data
+
+
+@app.post("/clear/{dir_name}")
+async def clear_uploads(dir_name: str):
+    # 清空 dir_name 目录
+    if dir_name == 'data':
+        shutil.rmtree(DATA_FOLDER)
+        os.makedirs(DATA_FOLDER)
+    elif dir_name == 'models':
+        shutil.rmtree(MODEL_FOLDER)
+    return {dir_name, " 目录已经清空"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=6664)

+ 12 - 0
show_server.py

@@ -0,0 +1,12 @@
+import os
+
+
+if __name__ == '__main__':
+    folder_path = './newest_log'
+    if not os.path.exists(folder_path):
+        os.mkdir(folder_path)
+
+    if len(os.listdir(folder_path)) == 0:
+        print('没有 log')
+    else:
+        print(os.system(f'tensorboard --logdir={folder_path} --host=0.0.0.0 --port 6662'))

+ 276 - 0
train.py

@@ -0,0 +1,276 @@
+from __future__ import print_function, division
+
+import shutil
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+import torch.backends.cudnn as cudnn
+import numpy as np
+import torchvision
+from torchvision import datasets, models, transforms
+from fastapi import FastAPI, File, UploadFile
+import uvicorn
+import json
+import time
+import os
+import copy
+import platform
+from torch.utils.tensorboard import SummaryWriter
+
+# Turn on the cuDNN benchmark flag (torch.backends.cudnn.benchmark).
+cudnn.benchmark = True
+
+# FastAPI application that the training route below registers on.
+app = FastAPI()
+
+
+def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
+    # 计算完成百分比
+    percent_complete = f"{(100 * (iteration / float(total))):.{decimals}f}"
+    # 计算进度条填充长度
+    filled_length = int(length * iteration // total)
+    # 创建进度条字符串
+    bar = fill * filled_length + '-' * (length - filled_length)
+    # 打印进度条
+    print(f'\r{prefix} |{bar}| {percent_complete}% {suffix}', end=print_end)
+    # 完成时打印新行
+    if iteration == total:
+        print()
+
+
+# 备份log
+def move_log(model_id):
+    if not os.path.exists(output_dir):
+        print('缺失路径: ', output_dir)
+        return
+
+    log_name = os.listdir(newest_log)[0]
+    log_path = os.path.join(newest_log, log_name)
+    save_path = os.path.join(output_dir, log_name)
+
+    shutil.copy(log_path, save_path)
+    print('log 已备份')
+
+
+def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
+    since = time.time()
+
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_acc = 0.0
+
+    for epoch in range(num_epochs):
+        print(f'Epoch {epoch + 1}/{num_epochs}')
+        print('-' * 10)
+
+        for phase in data_phase:
+            if phase == 'train':
+                model.train()
+            else:
+                model.eval()
+
+            running_loss = 0.0
+            running_corrects = 0
+
+            l = len(dataloaders[phase])
+            print_progress_bar(0, l, prefix='进度:', suffix='完成', length=50)
+
+            # Iterate over data.
+            for i, (inputs, labels) in enumerate(dataloaders[phase]):
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+
+                optimizer.zero_grad()
+
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    _, preds = torch.max(outputs, 1)
+                    loss = criterion(outputs, labels)
+
+                    # backward + optimize only if in training phase
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                # statistics
+                running_loss += loss.item() * inputs.size(0)
+                running_corrects += torch.sum(preds == labels.data)
+
+                # 更新进度条
+                print_progress_bar(i + 1, l, prefix='进度:', suffix='完成', length=50)
+
+            if phase == 'train':
+                scheduler.step()
+
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+            writer.add_scalar(phase + " loss", epoch_loss, epoch + 1)
+            writer.add_scalar(phase + " accuracy", epoch_loss, epoch + 1)
+
+            print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
+
+            # deep copy the model
+            if phase == 'val' and epoch_acc >= best_acc:
+                best_acc = epoch_acc
+                best_model_wts = copy.deepcopy(model.state_dict())
+                print('copy best model')
+
+        print()
+
+    time_elapsed = time.time() - since
+    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
+    print(f'Best val Acc: {best_acc:4f}')
+
+    # 加载在验证集中表现最好的模型
+    model.load_state_dict(best_model_wts)
+
+    return model
+
+
+def train(epoch=30, save_path='resnet.pth', load_my_model=False, model_path=None,
+          is_freeze=True, freeze_num=7, is_transfer_learn=False, transfer_cls=None):
+    # 如果不加载训练过的模型则加载预训练模型
+    if load_my_model:
+        model = models.resnet50(pretrained=False)
+        num_features = model.fc.in_features
+        if is_transfer_learn:
+            # 加载旧模型后,更改为新模型的分类格式
+            model.fc = nn.Linear(num_features, transfer_cls)
+            model.load_state_dict(torch.load(model_path))
+
+            # 修改最后一层
+            num_features = model.fc.in_features
+            model.fc = nn.Linear(num_features, class_num)
+        else:
+            # 修改最后一层
+            model.fc = nn.Linear(num_features, class_num)
+            model.load_state_dict(torch.load(model_path))
+
+    else:
+        model = models.resnet50(pretrained=True)
+        # 修改最后一层
+        num_features = model.fc.in_features
+        model.fc = nn.Linear(num_features, class_num)
+
+    model = model.to(device)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer_ft = optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)
+    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
+
+    # 冻结部分参数
+    if is_freeze:
+        for i, c in enumerate(model.children()):
+            if i == freeze_num:
+                break
+            for param in c.parameters():
+                param.requires_grad = False
+
+        for param in model.parameters():
+            print(param.requires_grad)
+
+    model = train_model(model, criterion, optimizer_ft,
+                        exp_lr_scheduler, num_epochs=epoch)
+
+    torch.save(model.state_dict(), save_path)
+
+
+@app.post("/train/params_json")
+async def upload_json(file: UploadFile = File(...)):
+    contents = await file.read()
+    json_data = contents.decode("utf-8")
+    # 处理JSON数据
+    print(json_data)
+    json_data = json.loads(json_data)
+
+    global data_dir, device, class_num, dataloaders, dataset_sizes, output_dir, writer
+
+    data_dir = json_data['data_dir_path']
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    data_transforms = {
+        'train': transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.RandAugment(),
+            transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
+            transforms.ColorJitter(brightness=0.2, contrast=0.2),
+            transforms.RandomApply([transforms.GaussianBlur(5)], p=0.3),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ]),
+        'val': transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ]),
+    }
+
+    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
+                      for x in data_phase}
+    dataloaders = {
+        x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=16, pin_memory=True)
+        for x in data_phase}
+    dataset_sizes = {x: len(image_datasets[x]) for x in data_phase}
+    class_names = image_datasets['train'].classes
+
+    class_num = len(class_names)
+    print('class:', class_num)
+
+    '''
+    epoch: 训练次数
+    save_path: 模型保存路径
+    load_my_model: 是否加载训练过的模型
+    model_path: 加载的模型路径 
+    is_freeze: 是否冻结模型部分层数
+    freeze_num: 冻结层数
+    is_transfer_learn: 是否迁移学习
+    transfer_cls: 迁移学习旧模型分类头数量
+    '''
+    # 数据集路径在本文件上面
+    epoch = json_data['params']['epoch']
+    load_my_model = json_data['load_my_model']
+    model_path = json_data['model_path']
+    is_transfer_learn = json_data['is_transfer_learn']
+    transfer_cls = json_data['transfer_cls']
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    model_id = len(os.listdir(output_dir)) + 1
+    output_dir = os.path.join(output_dir, str(model_id))
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+
+    # 删除旧的log
+    if len(os.listdir(newest_log)) > 0:
+        os.remove(os.path.join(newest_log, os.listdir(newest_log)[0]))
+    writer = SummaryWriter(newest_log)
+    writer.add_text('model', "model id: " + str(model_id))
+
+    save_path = os.path.join(output_dir, str(json_data['model_name']) + '_out' + str(class_num) + '.pth')
+
+    train(epoch=epoch, save_path=save_path,
+          load_my_model=load_my_model,
+          model_path=model_path,
+          is_transfer_learn=is_transfer_learn, transfer_cls=transfer_cls
+          )
+    writer.flush()
+    writer.close()
+    move_log(model_id)
+
+    return {"train end"}
+
+
+if __name__ == '__main__':
+    data_dir: str
+    device = None
+    class_num: int
+    dataloaders = None
+    dataset_sizes = None
+
+    data_phase = ['train', 'val']
+    output_dir = './runs'
+    newest_log = './newest_log'
+    writer: SummaryWriter
+
+    uvicorn.run(app, host='0.0.0.0', port=6661)

+ 17 - 0
train_params.json

@@ -0,0 +1,17 @@
+{
+  "model_name": "resnet50",
+  "data_dir_path": "data/data_test",
+
+  "load_my_model": false,
+  "model_path": "",
+
+  "is_transfer_learn": false,
+  "transfer_cls": 12796,
+
+  "params":{
+    "epoch": 30,
+    "learn_rate": 0.01,
+    "momentum": 0.9,
+    "decay": 0.7
+  }
+}