Procházet zdrojové kódy

Merge branch 'dev-python'

rambo před 11 měsíci
rodič
revize
5e1fee4db8

+ 4 - 2
python/models.py

@@ -50,9 +50,11 @@ class ModelFormModel(BaseModel):
     size_mode: int = Field(default=0, description="尺寸模式;0=>指定大小;1=>最小边框")
     output_mode: int = Field(default=0, description="输出模式;0=>透明底;1=>白底图")
     need_cutout_images: list = Field(default=None, description="图像地址集合")
-    
+
 class SearchProgress(BaseModel):
     '''进度查询'''
     token: str = Field(default=None, description="hlm token信息")
     generate_ids:list[int] = Field(default=[], description="生成记录ID数组")
-    type: str = Field(default="aigc_pro", description="进度类型")
+    type: str = Field(default="aigc_pro", description="进度类型")
+    result: list = Field(default=None, description="图像地址集合")
+    save_root_path: str = Field(default="", description="保存图像地址")

+ 2 - 2
python/services/SegmentService.py

@@ -55,13 +55,13 @@ class SegmentService:
                 raise UnicornException("目录下文件过多,请检查目录是否正确")
             file_path = "{}/{}".format(root_path, file)
             if os.path.isdir(file_path):
-                print(file_path)
+                print("file_path", file_path, file)
                 if file == "已扣图":
                     # 哪些图片已经有抠图
                     for x_file in os.listdir(file_path):
                         x_file_name, x_file_e = os.path.splitext(x_file)
                         if x_file_e == ".png":
-                            continue
+                            # continue
                             _is_cutout.append(x_file_name)
 
         # ===============================================================

+ 96 - 64
python/services/deal_cutout.py

@@ -1,25 +1,31 @@
 import time
 from concurrent.futures import as_completed, ThreadPoolExecutor, wait
 import threading
-from .remove_bg_pixian import RemoveBgPiXian,Picture
+from .remove_bg_pixian import RemoveBgPiXian, Picture
 from .other.module_online_data import GetOnlineData
 from .deal_one_image import DealOneImage, DealOneImageBeforehand
 from .other.log import MyLogger
-from models import  UnicornException
+from models import UnicornException
 import pandas as pd
 import csv
 from PIL import Image
 from io import BytesIO
-import os,requests,io
+import os, requests, io
+
+
 def urlPilImage(url):
     yzmdata = requests.get(url)
     tempIm = BytesIO(yzmdata.content)
     im = Image.open(tempIm)
     return im
+
+
 def check_path(_path):
-        if not os.path.exists(_path):
-            os.mkdir(_path)
-        return True
+    if not os.path.exists(_path):
+        os.mkdir(_path)
+    return True
+
+
 class DealCutout:
 
     def __init__(self, token):
@@ -70,6 +76,8 @@ class DealCutout:
         print("self.need_cutout_images", self.need_cutout_images)
         save_root_path = ""
         for image_data in self.need_cutout_images:
+            if image_data["need_cutout"] == False:
+                continue
             num += 1
             save_root_path = image_data["root_path"]
             result = DealOneImageBeforehand(
@@ -105,6 +113,8 @@ class DealCloths:
         result_array = []
         save_root_path = ""
         for image_data in self.need_cutout_images:
+            if image_data["need_cutout"] == False:
+                continue
             num += 1
             save_root_path = image_data["root_path"]
             upload_pic_dict = {}
@@ -119,10 +129,13 @@ class DealCloths:
             result_array.append(upload_pic_dict)
         return result_array, save_root_path
 
+
 class DealModelForm:
-    '''处理人台抠图'''
-    csvName = 'record.csv'
-    def __init__(self, token,params):
+    """处理人台抠图"""
+
+    csvName = "record.csv"
+
+    def __init__(self, token, params):
         super().__init__()
         self.lock = threading.Lock()
         self.need_cutout_images = {}
@@ -136,23 +149,37 @@ class DealModelForm:
         # 图片列表
         self.upload_pic_dict = {}
         self.logger = MyLogger().logger
-    def addData2Csv(self,data):
-        name_list = ['file_name', 'file_e', 'file_path', 'file','root_path','need_cutout','image_url','generate_id','status'] 
+
+    def addData2Csv(self, data):
+        name_list = [
+            "file_name",
+            "file_e",
+            "file_path",
+            "file",
+            "root_path",
+            "need_cutout",
+            "image_url",
+            "generate_id",
+            "status",
+        ]
         isExist = os.path.exists(self.csvName)
-        csvfile = open(self.csvName,"a", encoding='utf-8-sig')
+        csvfile = open(self.csvName, "a", encoding="utf-8-sig")
         writer = csv.writer(csvfile)
-            #先写入columns_name
+        # 先写入columns_name
         if isExist == False:
-                writer.writerow(name_list)
+            writer.writerow(name_list)
         writer.writerows(data)
         csvfile.close()
+
     def startDispose(self):
         self.get_online_data.refresh_headers()
         num = 0
         save_root_path = ""
         baseImages = []
-        resize=1600#定义标准
-        for index,image_data in enumerate(self.need_cutout_images):
+        resize = 1600  # 定义标准
+        for index, image_data in enumerate(self.need_cutout_images):
+            if image_data["need_cutout"] == False:
+                continue
             num += 1
             save_root_path = image_data["root_path"]
             file_path = image_data["file_path"]
@@ -164,86 +191,91 @@ class DealModelForm:
                 if original_pic.y > resize:
                     original_pic.resize_by_heigh(heigh=resize)
             buffer = io.BytesIO()
-            original_pic.im.save(buffer, format='JPEG')
+            original_pic.im.save(buffer, format="JPEG")
             buffer.seek(0)
-            image_url = self.get_online_data.upload_pic(file_path=None,buffer=buffer)
+            image_url = self.get_online_data.upload_pic(file_path=None, buffer=buffer)
             baseImages.append(image_url)
             self.need_cutout_images[index]["image_url"] = image_url
         data = {
-            "base_image":baseImages,
-            "out_width":self.params.out_width,
-            "out_height":self.params.out_height,
-            "size_mode":self.params.size_mode,
-            "output_mode":self.params.output_mode,
+            "base_image": baseImages,
+            "out_width": self.params.out_width,
+            "out_height": self.params.out_height,
+            "size_mode": self.params.size_mode,
+            "output_mode": self.params.output_mode,
         }
         result_json = self.get_online_data.model_form_segment(data)
         generate_ids = result_json.get("generate_ids")
         saveParams = []
-        for idx,id in enumerate(generate_ids):
-            self.need_cutout_images[idx]['generate_id'] = id
-            # ['file_name', 'file_e', 'file_path', 
-            # 'file','root_path','need_cutout','image_url','generate_id','status'] 
+        for idx, id in enumerate(generate_ids):
+            if self.need_cutout_images[idx]["need_cutout"] == False:
+                continue
+            self.need_cutout_images[idx]["generate_id"] = id
+            # ['file_name', 'file_e', 'file_path',
+            # 'file','root_path','need_cutout','image_url','generate_id','status']
             item = self.need_cutout_images[idx]
-            saveParams.append([item['file_name'],
-                               item['file_e'],
-                               item['file_path'],
-                               item['file'],
-                               item['root_path'],
-                               item['need_cutout'],
-                               item['image_url'],
-                               item['generate_id'],
-                               False,
-                               ])
-        self.addData2Csv(saveParams)
-        return self.need_cutout_images, save_root_path,generate_ids
+            saveParams.append(
+                [
+                    item["file_name"],
+                    item["file_e"],
+                    item["file_path"],
+                    item["file"],
+                    item["root_path"],
+                    item["need_cutout"],
+                    item["image_url"],
+                    item["generate_id"],
+                    False,
+                ]
+            )
+        # self.addData2Csv(saveParams)
+        return self.need_cutout_images, save_root_path, generate_ids
+
     def search_progress(self):
-        try:
-          csvData = pd.read_csv(self.csvName)
-        except FileNotFoundError as e:
-          raise UnicornException("不存在生成记录,请先提交抠人台抠图任务")
-        '''进度查询'''
-        print("self.params",self.params)
+        # try:
+        #     csvData = pd.read_csv(self.csvName)
+        # except FileNotFoundError as e:
+        #     raise UnicornException("不存在生成记录,请先提交抠人台抠图任务")
+        """进度查询"""
+        print("self.params", self.params)
         search_generate_ids = self.params.generate_ids
         dataParams = {
-            "generate_ids":search_generate_ids,
-            "type":self.params.type,
+            "generate_ids": search_generate_ids,
+            "type": self.params.type,
         }
         responseData = self.get_online_data.search_progress(dataParams)
-        generate_ids = csvData.loc[csvData['generate_id'].isin(search_generate_ids)]
+        generate_ids = self.params.result
         successCount = 0
         failCount = 0
         is_finished = False
         root_path = ""
         if len(generate_ids) > 0:
-            print("generate_ids",generate_ids)
-            root_path = generate_ids.iloc[0]["root_path"]
+            print("generate_ids", generate_ids)
+            root_path = generate_ids[0]["root_path"]
             save_path = f"{root_path}/已扣图"
             check_path(save_path)
         else:
-            return is_finished,successCount,failCount,root_path
-        for idx,generate in generate_ids.iterrows():
+            return is_finished, successCount, failCount, root_path
+        for idx, generate in enumerate(generate_ids):
             for respItem in responseData:
-                if generate["generate_id"]!=respItem["id"]:
+                if generate["generate_id"] != respItem["id"]:
                     continue
-                status = respItem['status']
-                print("status",status)
+                status = respItem["status"]
+                print("status", status)
                 if status == -1:
-                    failCount+=1
-                    csvData.drop(csvData.loc[csvData['generate_id'] == generate["generate_id"]].index,inplace=True)
+                    failCount += 1
                     break
                 if status == 2:
-                    successCount+=1
-                    result_image_url = respItem['result_image_urls'][0]
+                    successCount += 1
+                    result_image_url = respItem["result_image_urls"][0]
                     result_image_pil = urlPilImage(result_image_url)
                     root_path = generate["root_path"]
                     file_name = generate["file_name"]
                     file_e = generate["file_e"]
-                    if result_image_pil.mode == 'RGBA':
+                    if result_image_pil.mode == "RGBA":
                         result_image_pil.save(f"{save_path}/{file_name}.png")
                     else:
                         result_image_pil.save(f"{save_path}/{file_name}.jpg")
-                    csvData.drop(csvData.loc[csvData['generate_id'] == generate["generate_id"]].index,inplace=True)
                     break
-        csvData.to_csv(self.csvName, index=False)
-        is_finished = True if len(search_generate_ids) == (successCount+failCount) else False
-        return is_finished,successCount,failCount,root_path
+        is_finished = (
+            True if len(search_generate_ids) == (successCount + failCount) else False
+        )
+        return is_finished, successCount, failCount, root_path