浏览代码

Merge branch 'dev-python'

rambo 11 月之前
父节点
当前提交
5e1fee4db8
共有 3 个文件被更改,包括 102 次插入68 次删除
  1. 4 2
      python/models.py
  2. 2 2
      python/services/SegmentService.py
  3. 96 64
      python/services/deal_cutout.py

+ 4 - 2
python/models.py

@@ -50,9 +50,11 @@ class ModelFormModel(BaseModel):
     size_mode: int = Field(default=0, description="尺寸模式;0=>指定大小;1=>最小边框")
     size_mode: int = Field(default=0, description="尺寸模式;0=>指定大小;1=>最小边框")
     output_mode: int = Field(default=0, description="输出模式;0=>透明底;1=>白底图")
     output_mode: int = Field(default=0, description="输出模式;0=>透明底;1=>白底图")
     need_cutout_images: list = Field(default=None, description="图像地址集合")
     need_cutout_images: list = Field(default=None, description="图像地址集合")
-    
+
 class SearchProgress(BaseModel):
 class SearchProgress(BaseModel):
     '''进度查询'''
     '''进度查询'''
     token: str = Field(default=None, description="hlm token信息")
     token: str = Field(default=None, description="hlm token信息")
     generate_ids:list[int] = Field(default=[], description="生成记录ID数组")
     generate_ids:list[int] = Field(default=[], description="生成记录ID数组")
-    type: str = Field(default="aigc_pro", description="进度类型")
+    type: str = Field(default="aigc_pro", description="进度类型")
+    result: list = Field(default=None, description="图像地址集合")
+    save_root_path: str = Field(default="", description="保存图像地址")

+ 2 - 2
python/services/SegmentService.py

@@ -55,13 +55,13 @@ class SegmentService:
                 raise UnicornException("目录下文件过多,请检查目录是否正确")
                 raise UnicornException("目录下文件过多,请检查目录是否正确")
             file_path = "{}/{}".format(root_path, file)
             file_path = "{}/{}".format(root_path, file)
             if os.path.isdir(file_path):
             if os.path.isdir(file_path):
-                print(file_path)
+                print("file_path", file_path, file)
                 if file == "已扣图":
                 if file == "已扣图":
                     # 哪些图片已经有抠图
                     # 哪些图片已经有抠图
                     for x_file in os.listdir(file_path):
                     for x_file in os.listdir(file_path):
                         x_file_name, x_file_e = os.path.splitext(x_file)
                         x_file_name, x_file_e = os.path.splitext(x_file)
                         if x_file_e == ".png":
                         if x_file_e == ".png":
-                            continue
+                            # continue
                             _is_cutout.append(x_file_name)
                             _is_cutout.append(x_file_name)
 
 
         # ===============================================================
         # ===============================================================

+ 96 - 64
python/services/deal_cutout.py

@@ -1,25 +1,31 @@
 import time
 import time
 from concurrent.futures import as_completed, ThreadPoolExecutor, wait
 from concurrent.futures import as_completed, ThreadPoolExecutor, wait
 import threading
 import threading
-from .remove_bg_pixian import RemoveBgPiXian,Picture
+from .remove_bg_pixian import RemoveBgPiXian, Picture
 from .other.module_online_data import GetOnlineData
 from .other.module_online_data import GetOnlineData
 from .deal_one_image import DealOneImage, DealOneImageBeforehand
 from .deal_one_image import DealOneImage, DealOneImageBeforehand
 from .other.log import MyLogger
 from .other.log import MyLogger
-from models import  UnicornException
+from models import UnicornException
 import pandas as pd
 import pandas as pd
 import csv
 import csv
 from PIL import Image
 from PIL import Image
 from io import BytesIO
 from io import BytesIO
-import os,requests,io
+import os, requests, io
+
+
 def urlPilImage(url):
 def urlPilImage(url):
     yzmdata = requests.get(url)
     yzmdata = requests.get(url)
     tempIm = BytesIO(yzmdata.content)
     tempIm = BytesIO(yzmdata.content)
     im = Image.open(tempIm)
     im = Image.open(tempIm)
     return im
     return im
+
+
 def check_path(_path):
 def check_path(_path):
-        if not os.path.exists(_path):
-            os.mkdir(_path)
-        return True
+    if not os.path.exists(_path):
+        os.mkdir(_path)
+    return True
+
+
 class DealCutout:
 class DealCutout:
 
 
     def __init__(self, token):
     def __init__(self, token):
@@ -70,6 +76,8 @@ class DealCutout:
         print("self.need_cutout_images", self.need_cutout_images)
         print("self.need_cutout_images", self.need_cutout_images)
         save_root_path = ""
         save_root_path = ""
         for image_data in self.need_cutout_images:
         for image_data in self.need_cutout_images:
+            if image_data["need_cutout"] == False:
+                continue
             num += 1
             num += 1
             save_root_path = image_data["root_path"]
             save_root_path = image_data["root_path"]
             result = DealOneImageBeforehand(
             result = DealOneImageBeforehand(
@@ -105,6 +113,8 @@ class DealCloths:
         result_array = []
         result_array = []
         save_root_path = ""
         save_root_path = ""
         for image_data in self.need_cutout_images:
         for image_data in self.need_cutout_images:
+            if image_data["need_cutout"] == False:
+                continue
             num += 1
             num += 1
             save_root_path = image_data["root_path"]
             save_root_path = image_data["root_path"]
             upload_pic_dict = {}
             upload_pic_dict = {}
@@ -119,10 +129,13 @@ class DealCloths:
             result_array.append(upload_pic_dict)
             result_array.append(upload_pic_dict)
         return result_array, save_root_path
         return result_array, save_root_path
 
 
+
 class DealModelForm:
 class DealModelForm:
-    '''处理人台抠图'''
-    csvName = 'record.csv'
-    def __init__(self, token,params):
+    """处理人台抠图"""
+
+    csvName = "record.csv"
+
+    def __init__(self, token, params):
         super().__init__()
         super().__init__()
         self.lock = threading.Lock()
         self.lock = threading.Lock()
         self.need_cutout_images = {}
         self.need_cutout_images = {}
@@ -136,23 +149,37 @@ class DealModelForm:
         # 图片列表
         # 图片列表
         self.upload_pic_dict = {}
         self.upload_pic_dict = {}
         self.logger = MyLogger().logger
         self.logger = MyLogger().logger
-    def addData2Csv(self,data):
-        name_list = ['file_name', 'file_e', 'file_path', 'file','root_path','need_cutout','image_url','generate_id','status'] 
+
+    def addData2Csv(self, data):
+        name_list = [
+            "file_name",
+            "file_e",
+            "file_path",
+            "file",
+            "root_path",
+            "need_cutout",
+            "image_url",
+            "generate_id",
+            "status",
+        ]
         isExist = os.path.exists(self.csvName)
         isExist = os.path.exists(self.csvName)
-        csvfile = open(self.csvName,"a", encoding='utf-8-sig')
+        csvfile = open(self.csvName, "a", encoding="utf-8-sig")
         writer = csv.writer(csvfile)
         writer = csv.writer(csvfile)
-            #先写入columns_name
+        # 先写入columns_name
         if isExist == False:
         if isExist == False:
-                writer.writerow(name_list)
+            writer.writerow(name_list)
         writer.writerows(data)
         writer.writerows(data)
         csvfile.close()
         csvfile.close()
+
     def startDispose(self):
     def startDispose(self):
         self.get_online_data.refresh_headers()
         self.get_online_data.refresh_headers()
         num = 0
         num = 0
         save_root_path = ""
         save_root_path = ""
         baseImages = []
         baseImages = []
-        resize=1600#定义标准
-        for index,image_data in enumerate(self.need_cutout_images):
+        resize = 1600  # 定义标准
+        for index, image_data in enumerate(self.need_cutout_images):
+            if image_data["need_cutout"] == False:
+                continue
             num += 1
             num += 1
             save_root_path = image_data["root_path"]
             save_root_path = image_data["root_path"]
             file_path = image_data["file_path"]
             file_path = image_data["file_path"]
@@ -164,86 +191,91 @@ class DealModelForm:
                 if original_pic.y > resize:
                 if original_pic.y > resize:
                     original_pic.resize_by_heigh(heigh=resize)
                     original_pic.resize_by_heigh(heigh=resize)
             buffer = io.BytesIO()
             buffer = io.BytesIO()
-            original_pic.im.save(buffer, format='JPEG')
+            original_pic.im.save(buffer, format="JPEG")
             buffer.seek(0)
             buffer.seek(0)
-            image_url = self.get_online_data.upload_pic(file_path=None,buffer=buffer)
+            image_url = self.get_online_data.upload_pic(file_path=None, buffer=buffer)
             baseImages.append(image_url)
             baseImages.append(image_url)
             self.need_cutout_images[index]["image_url"] = image_url
             self.need_cutout_images[index]["image_url"] = image_url
         data = {
         data = {
-            "base_image":baseImages,
-            "out_width":self.params.out_width,
-            "out_height":self.params.out_height,
-            "size_mode":self.params.size_mode,
-            "output_mode":self.params.output_mode,
+            "base_image": baseImages,
+            "out_width": self.params.out_width,
+            "out_height": self.params.out_height,
+            "size_mode": self.params.size_mode,
+            "output_mode": self.params.output_mode,
         }
         }
         result_json = self.get_online_data.model_form_segment(data)
         result_json = self.get_online_data.model_form_segment(data)
         generate_ids = result_json.get("generate_ids")
         generate_ids = result_json.get("generate_ids")
         saveParams = []
         saveParams = []
-        for idx,id in enumerate(generate_ids):
-            self.need_cutout_images[idx]['generate_id'] = id
-            # ['file_name', 'file_e', 'file_path', 
-            # 'file','root_path','need_cutout','image_url','generate_id','status'] 
+        for idx, id in enumerate(generate_ids):
+            if self.need_cutout_images[idx]["need_cutout"] == False:
+                continue
+            self.need_cutout_images[idx]["generate_id"] = id
+            # ['file_name', 'file_e', 'file_path',
+            # 'file','root_path','need_cutout','image_url','generate_id','status']
             item = self.need_cutout_images[idx]
             item = self.need_cutout_images[idx]
-            saveParams.append([item['file_name'],
-                               item['file_e'],
-                               item['file_path'],
-                               item['file'],
-                               item['root_path'],
-                               item['need_cutout'],
-                               item['image_url'],
-                               item['generate_id'],
-                               False,
-                               ])
-        self.addData2Csv(saveParams)
-        return self.need_cutout_images, save_root_path,generate_ids
+            saveParams.append(
+                [
+                    item["file_name"],
+                    item["file_e"],
+                    item["file_path"],
+                    item["file"],
+                    item["root_path"],
+                    item["need_cutout"],
+                    item["image_url"],
+                    item["generate_id"],
+                    False,
+                ]
+            )
+        # self.addData2Csv(saveParams)
+        return self.need_cutout_images, save_root_path, generate_ids
+
     def search_progress(self):
     def search_progress(self):
-        try:
-          csvData = pd.read_csv(self.csvName)
-        except FileNotFoundError as e:
-          raise UnicornException("不存在生成记录,请先提交抠人台抠图任务")
-        '''进度查询'''
-        print("self.params",self.params)
+        # try:
+        #     csvData = pd.read_csv(self.csvName)
+        # except FileNotFoundError as e:
+        #     raise UnicornException("不存在生成记录,请先提交抠人台抠图任务")
+        """进度查询"""
+        print("self.params", self.params)
         search_generate_ids = self.params.generate_ids
         search_generate_ids = self.params.generate_ids
         dataParams = {
         dataParams = {
-            "generate_ids":search_generate_ids,
-            "type":self.params.type,
+            "generate_ids": search_generate_ids,
+            "type": self.params.type,
         }
         }
         responseData = self.get_online_data.search_progress(dataParams)
         responseData = self.get_online_data.search_progress(dataParams)
-        generate_ids = csvData.loc[csvData['generate_id'].isin(search_generate_ids)]
+        generate_ids = self.params.result
         successCount = 0
         successCount = 0
         failCount = 0
         failCount = 0
         is_finished = False
         is_finished = False
         root_path = ""
         root_path = ""
         if len(generate_ids) > 0:
         if len(generate_ids) > 0:
-            print("generate_ids",generate_ids)
-            root_path = generate_ids.iloc[0]["root_path"]
+            print("generate_ids", generate_ids)
+            root_path = generate_ids[0]["root_path"]
             save_path = f"{root_path}/已扣图"
             save_path = f"{root_path}/已扣图"
             check_path(save_path)
             check_path(save_path)
         else:
         else:
-            return is_finished,successCount,failCount,root_path
-        for idx,generate in generate_ids.iterrows():
+            return is_finished, successCount, failCount, root_path
+        for idx, generate in enumerate(generate_ids):
             for respItem in responseData:
             for respItem in responseData:
-                if generate["generate_id"]!=respItem["id"]:
+                if generate["generate_id"] != respItem["id"]:
                     continue
                     continue
-                status = respItem['status']
-                print("status",status)
+                status = respItem["status"]
+                print("status", status)
                 if status == -1:
                 if status == -1:
-                    failCount+=1
-                    csvData.drop(csvData.loc[csvData['generate_id'] == generate["generate_id"]].index,inplace=True)
+                    failCount += 1
                     break
                     break
                 if status == 2:
                 if status == 2:
-                    successCount+=1
-                    result_image_url = respItem['result_image_urls'][0]
+                    successCount += 1
+                    result_image_url = respItem["result_image_urls"][0]
                     result_image_pil = urlPilImage(result_image_url)
                     result_image_pil = urlPilImage(result_image_url)
                     root_path = generate["root_path"]
                     root_path = generate["root_path"]
                     file_name = generate["file_name"]
                     file_name = generate["file_name"]
                     file_e = generate["file_e"]
                     file_e = generate["file_e"]
-                    if result_image_pil.mode == 'RGBA':
+                    if result_image_pil.mode == "RGBA":
                         result_image_pil.save(f"{save_path}/{file_name}.png")
                         result_image_pil.save(f"{save_path}/{file_name}.png")
                     else:
                     else:
                         result_image_pil.save(f"{save_path}/{file_name}.jpg")
                         result_image_pil.save(f"{save_path}/{file_name}.jpg")
-                    csvData.drop(csvData.loc[csvData['generate_id'] == generate["generate_id"]].index,inplace=True)
                     break
                     break
-        csvData.to_csv(self.csvName, index=False)
-        is_finished = True if len(search_generate_ids) == (successCount+failCount) else False
-        return is_finished,successCount,failCount,root_path
+        is_finished = (
+            True if len(search_generate_ids) == (successCount + failCount) else False
+        )
+        return is_finished, successCount, failCount, root_path