Переглянути джерело

Merge branch 'dev-python'

rambo 11 місяців тому
батько
коміт
eba344ebf9

+ 2 - 1
python/.gitignore

@@ -3,4 +3,5 @@
 .venv/*
 dist/*
 *.log
-log/*
+log/*
+*.csv

+ 21 - 5
python/api.py

@@ -1,7 +1,7 @@
 from models import *
 import datetime
 from services.SegmentService import SegmentService
-from services.deal_cutout import DealCutout, DealCloths
+from services.deal_cutout import DealCutout, DealCloths,DealModelForm
 from services.other.module_online_data import GetOnlineData
 
 
@@ -24,11 +24,11 @@ async def checkSelect(params: CheckSelectImages):
         need_cutout_images = service.initImages(image_list=image_list)
     else:
         need_cutout_images = service.check_need_cutout_images(file_path)
+    print("need_cutout_images",need_cutout_images)
     if len([x for x in need_cutout_images if x["need_cutout"]]) == 0:
-        raise UnicornException("您所选文件夹下没有jpg图片,或对应图片已扣图")
+        raise UnicornException("您所选文件夹下没有jpg/png图片,或对应图片已扣图")
     return success(need_cutout_images)
 
-
 @app.post("/api/segment_images", description="执行抠图操作")
 async def segmentImages(params: SegmentImages):
     token = params.token
@@ -41,7 +41,7 @@ async def segmentImages(params: SegmentImages):
     result = None
     try:
         if len([x for x in need_cutout_images if x["need_cutout"]]) == 0:
-            raise UnicornException("您所选文件夹下没有jpg图片,或对应图片已扣图")
+            raise UnicornException("您所选文件夹下没有jpg/png图片,或对应图片已扣图")
     except Exception as e:
         raise UnicornException(repr(e))
     if image_type == 1 and segment_type == 1:
@@ -67,7 +67,23 @@ async def segmentImages(params: SegmentImages):
         result, save_root_path = deal_cutout_mode.normalMode()
     return success({"result": result, "save_root_path": save_root_path})
 
-
+@app.post("/api/model_form_segment", description="人台抠图")
+def model_form_segment(params:ModelFormModel):
+    need_cutout_images = params.need_cutout_images
+    try:
+        if len([x for x in need_cutout_images if x["need_cutout"]]) == 0:
+            raise UnicornException("您所选文件夹下没有jpg/png图片,或对应图片已扣图")
+    except Exception as e:
+        raise UnicornException(repr(e))
+    modelFormClazz = DealModelForm(token=params.token,params=params)
+    modelFormClazz.need_cutout_images = need_cutout_images
+    result ,save_root_path,generate_ids=modelFormClazz.startDispose()
+    return success({"result": result, "save_root_path": save_root_path,"generate_ids":generate_ids})
+@app.post("/api/search_bacth_progress", description="人台抠图")
+def search_bacth_progress(params:SearchProgress):
+    modelFormClazz = DealModelForm(token=params.token,params=params)
+    is_finished,successCount,failCount = modelFormClazz.search_progress()
+    return success({"is_finished": is_finished,"success_count": successCount,"fail_count": failCount})
 @app.post("/api/request_hlm", description="请求查询可用余额")
 async def requestHlm(params: RequestHlm):
     token = params.token

+ 2 - 2
python/config.ini

@@ -1,8 +1,8 @@
 [app]
 host=127.0.0.1
 port=7074
-debug=false
-env=prod
+debug=true
+env=dev
 [dev]
 origin= http://my2.pubdata.cn
 host= mybackend2.pubdata.cn

+ 2 - 8
python/main.py

@@ -38,12 +38,6 @@ if __name__ == "__main__":
     host = config.get("app", "host")
     port = config.get("app", "port")
     debug = config.get("app", "debug")
-    if debug == True:
-        uvicorn.run(app="api:app", host=host, port=int(port), reload=debug, loop="auto")
-    else:
-        uvicorn.run(app="api:app", host=host, port=int(port), loop="auto")
-    # app.run(port=int(port), use_reloader=bool(debug))
-    # 注册信号处理函数
-    # 或许flask内置的stdio与node.js stdio有冲突,导致控制台无法显示信息。
-    # 如果想要查看控制台输出,请单独启动服务 npm run dev-python
+    isDebug = True if debug =='true' else False
+    uvicorn.run(app="api:app", host=host, port=int(port), reload=isDebug, loop="auto")
     print("python server is running at port:", port)

+ 15 - 0
python/models.py

@@ -41,3 +41,18 @@ class SegmentImages(BaseModel):
 class RequestHlm(BaseModel):
     # 抠图
     token: str = Field(default=None, description="hlm token信息")
+
+class ModelFormModel(BaseModel):
+    '''人台抠图参数'''
+    token: str = Field(default=None, description="hlm token信息")
+    out_width: int = Field(default=1024, description="宽度;默认1024(仅在尺寸模式为【指定大小】时生效)")
+    out_height: int = Field(default=1024, description="高度;默认1024(仅在尺寸模式为【指定大小】时生效)")
+    size_mode: int = Field(default=0, description="尺寸模式;0=>指定大小;1=>最小边框")
+    output_mode: int = Field(default=0, description="输出模式;0=>透明底;1=>白底图")
+    need_cutout_images: list = Field(default=None, description="图像地址集合")
+    
+class SearchProgress(BaseModel):
+    '''进度查询'''
+    token: str = Field(default=None, description="hlm token信息")
+    generate_ids:list[int] = Field(default=[], description="生成记录ID数组")
+    type: str = Field(default="aigc_pro", description="进度类型")

BIN
python/requirements.txt


+ 5 - 3
python/services/SegmentService.py

@@ -42,6 +42,8 @@ class SegmentService:
             ".JPG",
             ".jpeg",
             ".JPEG",
+            ".png",
+            ".PNG",
         ]
         _is_cutout = []
         if not os.path.isdir(root_path):
@@ -59,6 +61,7 @@ class SegmentService:
                     for x_file in os.listdir(file_path):
                         x_file_name, x_file_e = os.path.splitext(x_file)
                         if x_file_e == ".png":
+                            continue
                             _is_cutout.append(x_file_name)
 
         # ===============================================================
@@ -87,10 +90,9 @@ class SegmentService:
                     break
             if not f:
                 continue
-
             need_cutout = False if file_name in _is_cutout else True
-            if os.path.exists("{}/{}.png".format(root_path, file_name)):
-                need_cutout = False
+            # if os.path.exists("{}/{}.png".format(root_path, file_name)):
+            #     need_cutout = False
 
             # 图片进行处理
             need_cutout_images.append(

+ 117 - 2
python/services/deal_cutout.py

@@ -5,8 +5,17 @@ from .remove_bg_pixian import RemoveBgPiXian
 from .other.module_online_data import GetOnlineData
 from .deal_one_image import DealOneImage, DealOneImageBeforehand
 from .other.log import MyLogger
-
-
+from models import  UnicornException
+import pandas as pd
+import csv
+from PIL import Image
+from io import BytesIO
+import os,requests
+def urlPilImage(url):
+    yzmdata = requests.get(url)
+    tempIm = BytesIO(yzmdata.content)
+    im = Image.open(tempIm)
+    return im
 class DealCutout:
 
     def __init__(self, token):
@@ -105,3 +114,109 @@ class DealCloths:
             upload_pic_dict = hand.get_image_cut_cloths(image_data)
             result_array.append(upload_pic_dict)
         return result_array, save_root_path
+
+class DealModelForm:
+    '''处理人台抠图'''
+    csvName = 'record.csv'
+    def __init__(self, token,params):
+        super().__init__()
+        self.lock = threading.Lock()
+        self.need_cutout_images = {}
+        self.token = token
+        self.output_type = 0
+        self.state = 2  # 1进行中 2停止
+        self.get_online_data = GetOnlineData(self.token)
+        self.is_upload_pic_num = 0
+        self.is_deal_num = 0
+        self.params = params
+        # 图片列表
+        self.upload_pic_dict = {}
+        self.logger = MyLogger().logger
+    def addData2Csv(self,data):
+        name_list = ['file_name', 'file_e', 'file_path', 'file','root_path','need_cutout','image_url','generate_id','status'] 
+        isExist = os.path.exists(self.csvName)
+        csvfile = open(self.csvName,"a")
+        writer = csv.writer(csvfile)
+            #先写入columns_name
+        if isExist == False:
+                writer.writerow(name_list)
+        writer.writerows(data)
+        csvfile.close()
+    def startDispose(self):
+        self.get_online_data.refresh_headers()
+        num = 0
+        save_root_path = ""
+        baseImages = []
+        for index,image_data in enumerate(self.need_cutout_images):
+            num += 1
+            save_root_path = image_data["root_path"]
+            file_path = image_data["file_path"]
+            image_url = self.get_online_data.upload_pic(file_path)
+            baseImages.append(image_url)
+            self.need_cutout_images[index]["image_url"] = image_url
+        data = {
+            "base_image":baseImages,
+            "out_width":self.params.out_width,
+            "out_height":self.params.out_height,
+            "size_mode":self.params.size_mode,
+            "output_mode":self.params.output_mode,
+        }
+        result_json = self.get_online_data.model_form_segment(data)
+        generate_ids = result_json.get("generate_ids")
+        saveParams = []
+        for idx,id in enumerate(generate_ids):
+            self.need_cutout_images[idx]['generate_id'] = id
+            # ['file_name', 'file_e', 'file_path', 
+            # 'file','root_path','need_cutout','image_url','generate_id','status'] 
+            item = self.need_cutout_images[idx]
+            saveParams.append([item['file_name'],
+                               item['file_e'],
+                               item['file_path'],
+                               item['file'],
+                               item['root_path'],
+                               item['need_cutout'],
+                               item['image_url'],
+                               item['generate_id'],
+                               False,
+                               ])
+        self.addData2Csv(saveParams)
+        return self.need_cutout_images, save_root_path,generate_ids
+    def search_progress(self):
+        try:
+          csvData = pd.read_csv(self.csvName)
+        except FileNotFoundError as e:
+          raise UnicornException("不存在生成记录,请先提交抠人台抠图任务")
+        '''进度查询'''
+        print("self.params",self.params)
+        search_generate_ids = self.params.generate_ids
+        dataParams = {
+            "generate_ids":search_generate_ids,
+            "type":self.params.type,
+        }
+        responseData = self.get_online_data.search_progress(dataParams)
+        generate_ids = csvData.loc[csvData['generate_id'].isin(search_generate_ids)]
+        successCount = 0
+        failCount = 0
+        is_finished = False
+        for idx,generate in generate_ids.iterrows():
+            filtered_results = list(filter(lambda d: d.get('id') == 3, responseData))
+            
+        csvData.to_csv(self.csvName)
+        is_finished = True if len(search_generate_ids) == (successCount+failCount) else False
+        return is_finished,successCount,failCount
+    # def getElement(self,responseData):
+    #  for respItem in responseData:
+    #             status = respItem['status']
+    #             if status == -1:
+    #                 failCount+=1
+    #                 csvData.drop(csvData.loc[csvData['generate_id'] == generate["generate_id"]].index,inplace=True)
+    #                 continue
+    #             if status == 2:
+    #                 successCount+=1
+    #                 result_image_url = respItem['result_image_urls'][0]
+    #                 result_image_pil = urlPilImage(result_image_url)
+    #                 root_path = generate["root_path"]
+    #                 file_name = generate["file"]
+    #                 result_image_pil.save(f"{root_path}/已扣图/{file_name}")
+    #                 csvData.drop(csvData.loc[csvData['generate_id'] == generate["generate_id"]].index,inplace=True)
+    #                 continue

+ 33 - 4
python/services/other/module_online_data.py

@@ -10,15 +10,13 @@ import os, io
 from PIL import Image
 from io import BytesIO
 import configparser, json
-
+from models import  UnicornException
 # ===============默认数据配置=====================
 config = configparser.ConfigParser()
 config_name = "config.ini"
 config.read(config_name)
 debug = config.get("app", "debug")
 env = config.get("app", "env")
-print("debug====>", debug)
-print("debug====>", debug,123)
 origin = config.get(env, "origin")
 host = config.get(env, "host")
 domain = config.get(env, "domain")
@@ -183,7 +181,38 @@ class GetOnlineData(object):
         # print(_s.text)
         response_data = _s.json()
         return response_data["data"]["url"]
-
+    def model_form_segment(self, data=None):
+        '''人台抠图任务api'''
+        url = "{domain}/api/ai_image/v2/model_form_segment".format(domain=domain)
+        headers = {
+            "Authorization": self.token,
+            "Origin": Origin,
+            "Host": Host,
+            "Content-Type": "application/json;charset=UTF-8",
+        }
+        paramsData = json.dumps(data)
+        _s = requests.post(url=url, headers=headers, data=paramsData, timeout=60)
+        response_data = _s.json()
+        responseCode = response_data['code']
+        if responseCode != 0:
+            raise UnicornException(response_data["message"])
+        return response_data["data"]
+    def search_progress(self, data=None):
+        '''进度查询'''
+        url = "{domain}/api/ai_image/main/search_bacth_progress".format(domain=domain)
+        headers = {
+            "Authorization": self.token,
+            "Origin": Origin,
+            "Host": Host,
+            "Content-Type": "application/json;charset=UTF-8",
+        }
+        paramsData = json.dumps(data)
+        _s = requests.post(url=url, headers=headers, data=paramsData, timeout=60)
+        response_data = _s.json()
+        responseCode = response_data['code']
+        if responseCode != 0:
+            raise UnicornException(response_data["message"])
+        return response_data["data"]
     def get_keys(self):
         k = "pxnib99dbchtmdm"
         s = "ub9uj5678gs4m2bnrass1t3tn6ughlk065ianosk06akagolcr2u"

+ 14 - 0
python/temp.py

@@ -0,0 +1,14 @@
+import pandas as pd
+import csv
+import os
+csvName = 'record.csv'
+try:
+    csvData = pd.read_csv(csvName)
+except FileNotFoundError as e:
+    print("不存在生成记录,请先提交抠人台抠图任务")
+generate_ids = csvData.loc[csvData['generate_id'].isin([1100559,1100560,1100562])]
+for idx,item in generate_ids.iterrows():
+    # print("idx",idx)
+    csvData.drop(csvData.loc[csvData['generate_id'] == item["generate_id"]].index,inplace=True)
+    # print("item",item)
+csvData.to_csv(csvName, index=False)