Python相关🔥

2023/9/26 python后端

pip 安装依赖包

# 清华源:https://pypi.tuna.tsinghua.edu.cn/simple

# 豆瓣源:http://pypi.douban.com/simple

# 阿里源:http://mirrors.aliyun.com/pypi/simple/

例如:
pip install 依赖包 --proxy=http://xx.xx.xx.xx:8888 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com

使用 `from Crypto.Cipher import AES` 需要安装提供 Crypto 包的 pycryptodome:

pip install pycryptodome==3.10.1

Python 在线运行 IDE:JupyterLab

pip install jupyterlab -i  https://pypi.tuna.tsinghua.edu.cn/simple
jupyter lab --generate-config
vim ~/.jupyter/jupyter_lab_config.py
# 在打开的配置文件中修改以下两项:
c.ServerApp.ip = '0.0.0.0'  # 监听所有网卡;也可替换为指定的IP地址
c.ServerApp.port = 9988  # 替换为你希望的端口号
jupyter lab password
执行:jupyter lab
后台执行:nohup jupyter lab > jupyter.log 2>&1 &

Python 搭建带用户认证的文件服务(基于 HTTP Basic 认证,并非 FTP 协议)

# -*- encoding: utf-8 -*-

import http.server
import socketserver
import base64

USERNAME = 'ftp_test'
PASSWORD = 'ftp_test123!!'

# Token expected after "Basic " in the Authorization header (base64 of "user:password").
AUTH_KEY = base64.b64encode('{}:{}'.format(USERNAME, PASSWORD).encode()).decode()


class BasicAuthHandler(http.server.SimpleHTTPRequestHandler):
    """Serve files from the working directory behind HTTP Basic authentication.

    GET and HEAD both require ``Authorization: Basic <AUTH_KEY>``; any other
    request receives a 401 challenge.
    """

    def _is_authorized(self):
        """Return True if the request carries exactly the expected Basic credentials."""
        import hmac

        header = self.headers.get("Authorization")
        if header is None:
            return False
        # Constant-time comparison so the token cannot be recovered through timing.
        return hmac.compare_digest(header.encode(), ("Basic " + AUTH_KEY).encode())

    def do_HEAD(self):
        """Answer HEAD like GET (headers only), but only for authenticated clients."""
        if self._is_authorized():
            super().do_HEAD()
        else:
            self.do_auth_head()

    def do_auth_head(self):
        """Send a 401 response that asks the client for Basic credentials."""
        self.send_response(401)
        self.send_header("WWW-Authenticate", 'Basic realm="FileServer"')
        self.send_header("Content-type", "text/html;charset=UTF-8")
        self.end_headers()

    def do_GET(self):
        """ Present frontpage with user authentication. """
        if self.headers.get("Authorization") is None:
            self.do_auth_head()
            self.wfile.write(b"no auth header received")
        elif self._is_authorized():
            super().do_GET()
        else:
            self.do_auth_head()
            # Never echo the submitted header: it would reflect credentials and
            # attacker-controlled bytes into a text/html response.
            self.wfile.write(b"not authenticated")


class ThreadingHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
    """HTTPServer that handles each connection in its own thread.

    Worker threads are daemonic, so the server does not wait for them on exit.
    """

    daemon_threads: bool = True


if __name__ == '__main__':
    # Listen on all interfaces, port 9988.
    listen_address = ('', 9988)
    print('server listening at', listen_address)
    server = ThreadingHTTPServer(listen_address, BasicAuthHandler)
    with server as httpd:
        httpd.serve_forever()

Python搭建Mock服务

# -*- encoding: utf-8 -*-
from abc import ABC
from tornado.options import options

import tornado.httpserver
import tornado.ioloop
import tornado.web
import tornado.log
import logging
import pymysql
import time
import json
import json5
import os
import re
import redis

# Shared Redis client (database 1) holding cached mock results.
redis_conn = redis.Redis(db=1, host='xx.xx.xx.xx', port=6379)


class MyDB(object):
    """Small wrapper around one autocommit pymysql connection with a DictCursor.

    Also usable as a context manager (``with MyDB(...) as db:``) so the
    connection is closed even when an exception escapes.
    """

    def __init__(self, host, port, user, pwd, dbname):
        self.conn = pymysql.connect(host=host, port=port, user=user, passwd=pwd, db=dbname, charset='utf8')
        self.conn.autocommit(True)
        self.cur = self.conn.cursor(cursor=pymysql.cursors.DictCursor)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def query(self, sql, params=None):
        """Execute *sql* and return every row as a dict.

        :param sql: SQL text; may contain ``%s`` placeholders when *params* is given.
        :param params: optional sequence/mapping that the driver escapes into the
            placeholders. With ``None`` the SQL is sent unchanged.
        :return: list of row dicts. On a database error the exception is logged and
            ``[]`` is returned (previously ``None``, which made callers' ``len()`` fail).
        """
        results = []
        try:
            self.cur.execute(sql, params)
            results = list(self.cur.fetchall())
        except Exception:
            logging.exception("SQL query failed")

        msg = f"\n\n【SQL】 {sql}\n"
        msg += f"【Results】\n {results}\n"
        logging.info(msg)

        return results

    def exec(self, sql, params=None):
        """Execute a statement that returns no rows; errors are logged, not raised.

        :param sql: SQL text; may contain ``%s`` placeholders when *params* is given.
        :param params: optional values bound by the driver.
        """
        try:
            self.cur.execute(sql, params)
            self.conn.commit()  # no-op under autocommit(True), harmless otherwise
        except Exception:
            logging.exception("SQL exec failed: %s", sql)

    def close(self):
        """Close the cursor, then the connection (the latter even if the former fails)."""
        try:
            self.cur.close()
        finally:
            self.conn.close()


class LogFormatter(tornado.log.LogFormatter):
    """Colored tornado formatter: ``[time file:function:line LEVEL] message``."""

    def __init__(self):
        line_format = ('%(color)s[%(asctime)s %(filename)s:%(funcName)s:%(lineno)d %(levelname)s]'
                       '%(end_color)s %(message)s')
        super().__init__(fmt=line_format, datefmt='%Y-%m-%d %H:%M:%S')


class MockRuleHandler(object):
    """Resolve the mock response, HTTP status and delay for one request.

    Request paths look like ``/<api_id>/<real path>``: the first number in the
    path is ``mock_api_info.id`` and the remainder is matched against
    ``mock_api_info.path``. Results are cached in Redis under ``<api_id>_result``
    unless the API is listed in NO_CACHE_API_LIST.

    SQL values are escaped by the driver (``cursor.mogrify``) instead of being
    formatted into the statement, because mobile, identity_no and the path all
    come straight from the HTTP request.
    """

    # Rule used when no decision matches, or when a decision has an unknown d_type.
    DEFAULT_RULE_NAME = 'A决策'
    # mock_decision.d_type (as text) -> mock_api_rules.name
    RULE_NAMES = {
        "1": "A决策",
        "2": "B决策",
        "3": "A-次级捞回"
    }
    # APIs whose results are never cached
    NO_CACHE_API_LIST = ('110', '216')

    def __init__(self):
        self.db = MyDB(options.mysql_host, options.mysql_port, options.mysql_user, options.mysql_pwd, options.mysql_db)
        self.api_id = None
        self.http_code = None
        self.delay_time = None
        self.response = None

    def _sql(self, template, args):
        """Render *template* with *args* escaped exactly as cursor.execute() would."""
        return self.db.cur.mogrify(template, args)

    @staticmethod
    def _split_api_path(path):
        """Split ``/<api_id>/rest`` into ``(api_id, '/rest')``.

        The first numeric run is the id, and the path is cut right after it, so
        the same digits appearing later (e.g. ``/110/api/v110``) stay intact.
        Returns ``(None, path)`` when the path contains no digits.
        """
        match = re.search(r"\d+\.?\d*", path)
        if match is None:
            return None, path
        return match.group(0), path[match.end():]

    def _set_not_found(self):
        """Answer 404 with an explanatory JSON body."""
        self.response = json.dumps({"response": "无此接口,请检查接口信息是否正确"}, ensure_ascii=False)
        self.http_code = 404
        self.delay_time = 0

    def check_response(self, api_info, resp):
        """Use the first matched rule's response, else the API's default response."""
        if resp:
            self.response = resp[0].get('response')
        else:
            self.response = api_info.get('response')

    def check_decision(self, api_info, params, body, _path):
        """Pick a decision-specific rule from the caller's mobile / identity_no.

        Values come from the query string and are overridden by the JSON(5) body.
        If no decision row matches, self.response is reset to None so later steps
        can fill it; otherwise the newest decision's rule response (or the API
        default) is used.
        """
        body = json5.loads(body) if body is not None and body != b'' else {}
        if not isinstance(body, dict):
            # Array / scalar bodies carry no mobile or identity_no fields.
            body = {}
        mobile = params.get('mobile')[0].decode() if 'mobile' in params else ''
        mobile = body.get('mobile') if 'mobile' in body else mobile
        identity_no = params.get('identity_no')[0].decode() if 'identity_no' in params else ''
        identity_no = body.get('identity_no') if 'identity_no' in body else identity_no

        sql = self._sql(
            "select d_type from mock_decision where mobile=%s or identity_no=%s order by create_time desc;",
            (str(mobile), str(identity_no)),
        )
        resp = self.db.query(sql)
        if not resp:
            self.response = None
            return
        # An unknown d_type now falls back to the default rule; it used to fall back
        # to the literal name "1", which matches no rule.
        rule_name = self.RULE_NAMES.get(str(resp[0].get('d_type')), self.DEFAULT_RULE_NAME)
        self._apply_named_rule(api_info, rule_name, _path)

    def _apply_named_rule(self, api_info, rule_name, _path):
        """Use the newest enabled rule named *rule_name* for the API mounted at *_path*."""
        sql = self._sql(
            "select * from mock_api_rules as mar "
            "left join mock_api_info as mai on mar.api_id = mai.id "
            "where mar.name=%s and mar.enabled=true and mai.`path` =%s "
            "order by mar.update_time desc limit 1;",
            (rule_name, _path),
        )
        self.check_response(api_info, self.db.query(sql))

    def check_rules(self, api_info):
        """Use the newest enabled rule of this API, else the API's default response."""
        sql = self._sql(
            "select response from mock_api_rules where api_id=%s and enabled=true "
            "order by update_time desc limit 1;",
            (self.api_id,),
        )
        self.check_response(api_info, self.db.query(sql))

    def decision_a(self, api_info, _path):
        """Fallback: use the newest enabled 'A决策' rule for the API at *_path*."""
        self._apply_named_rule(api_info, self.DEFAULT_RULE_NAME, _path)

    def match_rule(self, path, params, body):
        """Fill self.response, self.http_code and self.delay_time for the request.

        :param path: request path, ``/<api_id>/<real path>``
        :param params: tornado query arguments (name -> list of bytes values)
        :param body: raw request body

        The MySQL connection is always closed before returning, including on
        Redis cache hits and on errors.
        """
        try:
            self._resolve(path, params, body)
        finally:
            self.close()

    def _resolve(self, path, params, body):
        """Redis cache lookup, then database rule matching (see match_rule)."""
        self.api_id, _path = self._split_api_path(path)
        if self.api_id is None:
            self._set_not_found()
            return

        _key = f'{self.api_id}_result'
        value = redis_conn.get(_key)
        if value is not None:
            logging.info(f'缓存中存在{self.api_id},{path}接口信息,直接返回')
            value = json.loads(value)
            self.http_code = value.get('http_code')
            self.delay_time = value.get('delay_time')
            self.response = value.get('response')
            return

        # Not cached: look the interface up in the database
        logging.info('开始查询Mock接口是否存在')
        sql = self._sql(
            "select response,delay_time,http_code,`check` from mock_api_info where id=%s and path=%s;",
            (self.api_id, _path),
        )
        api_info = self.db.query(sql)
        if not api_info:
            # Deliberately not cached: the interface may be created later.
            self._set_not_found()
            return

        api_info = api_info[0]
        self.http_code = api_info.get('http_code')
        self.delay_time = api_info.get('delay_time')

        if api_info.get('check'):  # validate request parameters
            logging.info('开始校验决策入参逻辑,匹配决策数据,匹配到取决策返回值')
            # Matching decision data -> use the decision's response
            if self.response is None:
                self.check_decision(api_info, params, body, _path)

            logging.info('开始校验决策入参逻辑,未匹配到决策数据,默认走A決策')
            # No decision data matched -> default to decision A
            if self.response is None:
                self.decision_a(api_info, _path)

            logging.info('开始校验规则逻辑')
            # Enabled rules: the most recently updated one wins.
            # NOTE(review): decision_a falls back to the API's own response, so this
            # only runs when that response is NULL — confirm this is intended.
            if self.response is None:
                self.check_rules(api_info)

            # Default response
            if self.response is None:
                self.response = api_info.get('response')
        else:
            self.response = api_info.get('response')

        self.response = json.dumps(json5.loads(self.response))  # strip JSON5 comments

        if self.api_id not in self.NO_CACHE_API_LIST:
            # Cache the result for subsequent requests
            result = {
                "http_code": self.http_code,
                "delay_time": self.delay_time,
                "response": self.response
            }
            redis_conn.set(_key, json.dumps(result))

    def close(self):
        """Close the MySQL connection."""
        self.db.close()


class MockApiHandler(tornado.web.RequestHandler, ABC):
    """Catch-all handler: every HTTP method on every path gets the matched mock response."""

    def _resolve(self, path):
        """Match the request against the mock rules.

        :return: (response JSON text, delay in seconds, HTTP status). A NULL delay
            is treated as no delay and a NULL status as 200.
        """
        mock_rule = MockRuleHandler()
        mock_rule.match_rule(path, self.request.query_arguments, self.request.body)
        delay_time = float(mock_rule.delay_time or 0)
        http_code = mock_rule.http_code or 200
        return mock_rule.response, delay_time, http_code

    def _write_response(self, resp, http_code):
        """Log the exchange, then write *resp* (JSON text) with status *http_code*."""
        self.format_log(resp)
        self.set_status(http_code)
        self.set_header('Content-type', 'application/json;charset=UTF-8')
        try:
            data = json.loads(resp)
        except (TypeError, ValueError) as e:
            logging.error(e.__str__())
            self.write({"success": False, "err": e.__str__()})
            return
        if isinstance(data, dict):
            self.write(data)
        else:
            # RequestHandler.write() rejects lists and scalars; send the JSON text as-is.
            self.write(resp)

    def get_response(self, path):
        """Synchronous variant, kept for existing callers.

        time.sleep() blocks the whole IOLoop during the delay; the HTTP method
        handlers below use the non-blocking _respond() instead.
        """
        resp, delay_time, http_code = self._resolve(path)
        if delay_time:
            time.sleep(delay_time)
        self._write_response(resp, http_code)

    async def _respond(self, path):
        """Like get_response, but waits with asyncio.sleep so other requests keep flowing.

        NOTE(review): rule matching itself still does blocking MySQL/Redis I/O.
        """
        import asyncio

        resp, delay_time, http_code = self._resolve(path)
        if delay_time:
            await asyncio.sleep(delay_time)
        self._write_response(resp, http_code)

    def format_log(self, resp):
        """Log the full request (headers, method, path, params, body) and the response."""
        logging.info('请求完整信息')
        msg = f"\n\n【Headers】:\n{self.request.headers}"
        msg += f"【Method】:\n        {self.request.method}\n"
        msg += f"【Path】:\n        {self.request.uri}\n"
        msg += f"【Params】:\n        {self.request.query_arguments}\n"
        msg += f"【Body】:\n        {self.request.body}\n"
        msg += f"【Response】:\n        {resp}\n"
        logging.info(msg)

    async def get(self, *args):
        await self._respond(self.request.path)

    async def post(self, *args):
        await self._respond(self.request.path)

    async def put(self, *args):
        await self._respond(self.request.path)

    async def patch(self, *args):
        await self._respond(self.request.path)

    async def delete(self, *args):
        await self._respond(self.request.path)


if __name__ == '__main__':
    options.define("address", default="0.0.0.0", help="run on the given address", type=str)
    options.define("port", default=8891, help="run on the given port", type=int)

    options.define("mysql_host", default="xx.xx.xx.xx", help="mysql host", type=str)
    options.define("mysql_port", default=3306, help="mysql port", type=int)

    options.define("mysql_user", default="xxxx", help="mysql user", type=str)
    options.define("mysql_pwd", default="xxxxxx", help="mysql password", type=str)
    options.define("mysql_db", default="xxx", help="mysql dbname", type=str)

    # Log to server.log next to this script
    options.log_file_prefix = os.path.join(os.path.dirname(__file__), 'server.log')
    options.parse_command_line()
    # Application: every path is handled by MockApiHandler
    app = tornado.web.Application(
        handlers=[
            (r"/(.*)", MockApiHandler)
        ],
    )
    for handler in logging.getLogger().handlers:
        handler.setFormatter(LogFormatter())
    http_server = tornado.httpserver.HTTPServer(app)
    # Multi-process mode must use bind() + start(n): listen() would register the
    # sockets with an IOLoop before forking, which tornado does not support.
    http_server.bind(options.port, options.address)
    http_server.start(0)  # 0 = fork one process per CPU core (not available on Windows)
    print('server started')
    tornado.ioloop.IOLoop.current().start()

LangChain SQL Agent后端服务 - 数据猎手

from colorama import Fore
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from openai import OpenAI

import uvicorn
import json
import os

# Proxy and OpenAI credentials. setdefault() keeps values already present in the
# environment, so real secrets can be supplied at deploy time instead of being
# hard-coded; the literals below are placeholders used only when nothing is set.
os.environ.setdefault("HTTP_PROXY", "xx.xx.xx.xx:xxxx")
os.environ.setdefault("HTTPS_PROXY", "xx.xx.xx.xx:xxxx")
os.environ.setdefault("OPENAI_API_KEY", "xxxxxxxxxxxxxxx")

app = FastAPI()

# Create the SQLDatabase from a MySQL URI.
# Replace user, password, host, port and database name with your real settings.
db = SQLDatabase.from_uri("mysql+pymysql://user:password@xx.xx.xx.xx:xxxx/database")
print(db.dialect)  # SQL dialect in use, so generated SQL matches the database
print(db.get_usable_table_names())  # tables available to the agent
# db.run("SELECT * FROM Artist LIMIT 10;")  # run a SQL statement and print the result

llm = ChatOpenAI(model_name="gpt-4o-2024-05-13", temperature=0)
# Alternative: a self-hosted OpenAI-compatible model
# llm = ChatOpenAI(
#     model_name='model_name',
#     openai_api_base='http://xx.xx.xx.xx:xxxx/v1',
#     openai_api_key='EMPTY',
#     streaming=True,
# )

# Streaming smoke test for the LLM
# response = llm.stream(
#             input="你好",  # Chat history
#             # temperature=0,  # Temperature for text generation
#         )
# for chunk in response:
#     content = chunk.content or ""
#     print(Fore.GREEN + content, end="", flush=True)

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)


class QueryRequest(BaseModel):
    """Request body shared by /query and /image: the user's natural-language text."""
    question: str


def _extract_json_object(text):
    """Parse the span from the first '{' to the last '}' of *text* as JSON.

    The agent usually wraps its answer in a ```json fence, sometimes with extra
    prose around it; a fixed character slice broke whenever that wrapping varied.

    :raises ValueError: if *text* contains no JSON object or it does not parse.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object in agent output")
    return json.loads(text[start:end + 1])


@app.post("/query")
def query(request: QueryRequest):
    """Answer a natural-language database question with the SQL agent.

    Declared as a plain ``def`` so FastAPI runs the blocking agent call in its
    threadpool instead of stalling the event loop.

    :return: (query_sql, result_type, result_data). If the agent's answer cannot
        be parsed, ("", "text", <raw agent output or error message>).
    :raises HTTPException: 400 when the question is empty.
    """
    question = request.question

    if not question:
        raise HTTPException(status_code=400, detail="No question provided")
    result = None
    try:
        # Run the question through agent_executor; the fragments are newline-separated
        # so the user's question does not run into the instructions that follow it.
        message = (
            "你是智能SQL查询助理,专门负责根据用户的问题,高效查询数据库信息,\n"
            f"用户问题:{question}\n"
            "返回结果必须按如下json格式,不包含其他信息\n"
            "{\"query_sql\": \"query_sql\",\"result_type\": \"text or code\",\"result_data\": \"result_data\"}\n"
            "如果是画图意图,使用python代码生成图片数据后保存到内存中,最后返回缓冲区内容到image_result变量中\n"
            "如果返回的result_type是text,直接展示result_data\n"
            "如果是code,result_data内容使用python代码实现"
        )
        print(message)
        result = agent_executor.invoke({"input": message})
        print("result below: ")
        print(result)
        format_result = _extract_json_object(result["output"])

        return format_result["query_sql"], format_result["result_type"], format_result["result_data"]
    except Exception as e:
        print(str(e))
        # Show whatever the agent produced (or, if it failed, the error) as plain text.
        raw_output = result.get("output") if isinstance(result, dict) else None
        return "", "text", raw_output if raw_output is not None else str(e)


@app.post("/image")
def image(request: QueryRequest):
    """Generate one 1024x1024 DALL-E 3 image for the prompt and return its URL.

    Declared as a plain ``def`` so FastAPI runs the blocking OpenAI call in its
    threadpool instead of stalling the event loop during generation.

    :return: ``{"image_data": <url>}``, or ``{"image_data": None}`` on any error.
    :raises HTTPException: 400 when the question is empty.
    """
    question = request.question

    if not question:
        raise HTTPException(status_code=400, detail="No question provided")

    try:
        client = OpenAI()

        response = client.images.generate(
            model="dall-e-3",
            response_format="url",  # "url" or "b64_json"
            prompt=question,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_data = response.data[0].url
        return {"image_data": image_data}
    except Exception as e:
        print(str(e))
        return {"image_data": None}


# 运行应用时,使用 `uvicorn main:app --reload` 命令启动服务器
if __name__ == "__main__":
    # Serve on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")

Python搭建Web服务 - Gradio - 数据猎手

import io
import json
import gradio as gr
import requests
from PIL import Image

# URL of the FastAPI endpoint
FASTAPI_URL = "http://localhost:8000"


def query_api(question, timeout=None):
    """Ask the backend ``/query`` endpoint; return (query_sql, result_type, result_data).

    Always returns a 3-tuple, because the "查 询" button maps the result onto
    three output components. On failure the error text comes back as a "text"
    result (previously a dict was returned, which Gradio cannot spread over
    three outputs).

    :param timeout: optional requests timeout in seconds (default: wait forever).
    """
    try:
        response = requests.post(FASTAPI_URL + '/query', json={"question": question}, timeout=timeout)
    except requests.RequestException as e:
        return "", "text", str(e)

    if response.status_code == 200:
        data = response.json()

        return data[0], data[1], data[2]

    try:
        payload = response.json()
    except ValueError:
        payload = None
    detail = payload.get("detail", "An error occurred") if isinstance(payload, dict) else "An error occurred"
    return "", "text", str(detail)


def text_to_speech(content, yz_flag, ys_flag, yl_flag):
    """Gradio callback: speak *content* with the chosen voice label, rate and volume."""
    return generate_audio(content, yz_flag, ys_flag, yl_flag)


def text_to_image(content, timeout=None):
    """Ask the backend ``/image`` endpoint for a picture of *content*.

    :param timeout: optional requests timeout in seconds (default: wait forever).
    :return: the image URL for gr.Image, or None when the backend could not
        generate one.
    :raises gr.Error: when the request fails, so Gradio shows the message in the
        UI (a dict was returned before, which gr.Image cannot render).
    """
    try:
        response = requests.post(FASTAPI_URL + '/image', json={"question": content}, timeout=timeout)
    except requests.RequestException as e:
        raise gr.Error(f"Image request failed: {e}") from e

    if response.status_code == 200:
        data = response.json()
        print(data)
        return data["image_data"]

    try:
        payload = response.json()
    except ValueError:
        payload = None
    detail = payload.get("detail", "An error occurred") if isinstance(payload, dict) else "An error occurred"
    raise gr.Error(str(detail))


def drawer(code):
    """Run model-generated plotting code and return the PIL image it produced.

    The code must store the rendered image in a variable named ``image_result``,
    either as raw bytes or as a buffer exposing ``getvalue()`` (e.g. io.BytesIO).

    SECURITY: ``exec`` runs arbitrary Python with this process's privileges. The
    code comes from an LLM answer in a text box the user can edit, so only
    expose this app to trusted users, or run it in a sandbox.

    :raises gr.Error: when there is no code or it does not produce image_result.
    """
    if not code or not code.strip():
        raise gr.Error("没有可执行的代码")
    # One dict for both globals and locals. With two separate dicts, exec() behaves
    # like a class body, and functions defined by the code cannot see its own imports.
    namespace = {}
    exec(code, namespace)
    image_result = namespace.get('image_result')
    if hasattr(image_result, 'getvalue'):
        image_result = image_result.getvalue()
    if not isinstance(image_result, (bytes, bytearray)):
        raise gr.Error("代码没有把图片数据保存到 image_result 变量中")

    return Image.open(io.BytesIO(image_result))


# TTS voice id -> label shown in the "语种" dropdown. The labels must match the
# dropdown choices exactly; they previously read "英语", so English never matched.
TTS_VOICES = {
    'zh-CN-XiaoxiaoNeural': '中文 女性',
    'zh-CN-YunjianNeural': '中文 男性',
    'en-GB-LibbyNeural': '英文 女性',
    'en-GB-RyanNeural': '英文 男性'
}
# Voice id used for unknown labels (the old fallback was a label, not a voice id).
DEFAULT_TTS_VOICE = 'zh-CN-XiaoxiaoNeural'


def voice_for_label(label):
    """Return the TTS voice id for dropdown *label*, or DEFAULT_TTS_VOICE if unknown."""
    return next((voice for voice, text in TTS_VOICES.items() if text == label), DEFAULT_TTS_VOICE)


def generate_audio(content, yz_flag, ys_flag, yl_flag):
    """Synthesize *content* with the remote TTS service.

    :param content: text to speak
    :param yz_flag: voice label from the dropdown, e.g. "中文 女性"
    :param ys_flag: rate change in percent, sent as "+<n>%"
    :param yl_flag: volume change in percent, sent as "+<n>%"
    :return: the audio bytes on HTTP 200, otherwise None (errors are printed).
    """
    try:
        url = "http://xx.xx.xx.xx:xxxx/api/audios/generations"
        voice = voice_for_label(yz_flag)

        payload = {
            "text": content,
            "voice": voice,
            # Whole percentages ("+30%") even if the slider delivers floats (30.0).
            "rate": f"+{int(ys_flag)}%",
            "volume": f"+{int(yl_flag)}%"
        }
        print("TTS请求参数:")
        print(json.dumps(payload, ensure_ascii=False))
        headers = {
            'Accept': '*/*',
            'Accept-Language': 'zh-CN,zh;q=0.9',
            'Authorization': 'AIACC',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'Content-Type': 'application/json',
            'Origin': 'null',
            'Pragma': 'no-cache',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36'
        }

        response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
        print("TTS请求响应码:" + str(response.status_code))
        if response.status_code == 200:
            # Raw audio bytes; gr.Audio can play them directly.
            audio_content = response.content
            # Log the size only: printing the binary itself floods the console.
            print("TTS请求内容:" + f"{len(audio_content)} bytes")
            return audio_content
        else:
            print(f"Failed to get TTS: {response.status_code}")
            return None
    except Exception as e:
        print(f"Failed to get TTS: {str(e)}")
        return None


def clear_contents():
    """Return the reset value of each component cleared by the reset button, in order.

    Order: question, SQL, result type, text result, chart, image prompt,
    generated image, audio text, audio.
    """
    blank = ''
    return (blank, blank, 'text', blank, None,
            blank, None, blank, None)


def main():
    """Build the 数据猎手 Gradio UI (output / audio / image tabs) and launch it."""
    # Sample questions offered under the input box
    sample_questions = [
        "查询全部的表名",
        "查询xxx表设计并给出相应的优化建议",
        "查询xxx表,按月统计出前五条记录最多的月份,查询后的结果画柱状图",
        "查询xxx表,按月统计出前五条记录最多的月份,查询出的结果画饼图",
        "查询xxx表,按月统计出前五条记录最多的月份,查询出的结果画折线图",
        "查询xxx表,按月统计出前五条记录最多的月份,查询出的结果画散点图",
        "查询xxx表,按月统计出前五条记录最多的月份,查询出的结果画词云图"
    ]
    with gr.Blocks() as demo:
        gr.Markdown("## 数据猎手")
        with gr.Row():
            question_box = gr.TextArea(label="输入", placeholder="输入你的问题,例如:查询某表设计并给出相应的优化建议")
            # Clickable examples that fill the question box
            gr.Examples(sample_questions, inputs=question_box)

        sql_box = gr.Textbox(label="查询语句")
        # Hidden field carrying the backend's result type ("text" / "code")
        type_radio = gr.Radio(["text", "code"], label="结果类型", visible=False)

        with gr.Tab("输出"):
            with gr.Row():
                output_text = gr.TextArea(label="文本", placeholder="输出的内容")
                chart_image = gr.Image(label="图表")

            with gr.Row():
                # Query button: fills SQL, result type and text result
                query_btn = gr.Button("查 询", variant="primary")
                query_btn.click(fn=query_api, inputs=[question_box], outputs=[sql_box, type_radio, output_text])
                # Runs the returned code and shows the image it renders
                draw_btn = gr.Button("代码转图片", variant="primary")
                draw_btn.click(fn=drawer, inputs=[output_text], outputs=chart_image)

        with gr.Tab("音频"):
            with gr.Row():
                tts_text = gr.TextArea(label="文本", placeholder="输入文字内容")
                tts_audio = gr.Audio(label="音频")

            # Collapsible voice / rate / volume settings
            with gr.Accordion("其他设置", open=False):
                with gr.Row():
                    voice_choice = gr.Dropdown(
                        ["中文 女性", "中文 男性", "英文 女性", "英文 男性"],
                        label="语种",
                        value="中文 女性"
                    )
                    rate_slider = gr.Slider(minimum=0, maximum=100, label="语速", value=30)
                    volume_slider = gr.Slider(minimum=0, maximum=100, label="音量", value=60)

            tts_btn = gr.Button("文字转语音", variant="primary")
            tts_btn.click(
                fn=text_to_speech,
                inputs=[tts_text, voice_choice, rate_slider, volume_slider],
                outputs=tts_audio
            )

        with gr.Tab("图片"):
            with gr.Row():
                prompt_text = gr.TextArea(label="文本", placeholder="输入文字内容")
                generated_image = gr.Image(label="图片")

            image_btn = gr.Button("文字转图片", variant="primary")
            image_btn.click(fn=text_to_image, inputs=[prompt_text], outputs=generated_image)

        with gr.Row():
            # Reset every input/output component to its empty value
            reset_btn = gr.Button("重 置")
            reset_btn.click(
                fn=clear_contents,
                inputs=[],
                outputs=[
                    question_box,
                    sql_box,
                    type_radio,
                    output_text,
                    chart_image,
                    prompt_text,
                    generated_image,
                    tts_text,
                    tts_audio
                ]
            )

    # Alternatives:
    # demo.launch(server_name='0.0.0.0', server_port=7860, show_error=True, auth=("admin", "admin1234"))
    # demo.launch(share=True)
    demo.launch()


if __name__ == "__main__":
    main()

Windows 端口映射(portproxy)及防火墙规则管理器

from tkinter import ttk
from tkinter import messagebox
import tkinter as tk
import subprocess
import socket


# Characters that would break out of the quoted netsh command line run via shell=True
# (cmd.exe expands %VAR% even inside quotes).
_UNSAFE_RULE_NAME_CHARS = frozenset('"&|<>^%\r\n')


def check_fields(listen_port, rule_name):
    """Validate the Listen Port and Rule Name fields before building netsh commands.

    Shows a warning dialog and returns False when a field is empty, the port is
    not an integer in 1-65535, or the rule name contains characters that would
    break the shell command line. Returns True otherwise.
    """
    if not listen_port or not rule_name:
        messagebox.showwarning("Error", "Please ensure Listen Port and Rule Name are not empty.")
        return False
    port_text = str(listen_port)
    if not (port_text.isascii() and port_text.isdigit() and 1 <= int(port_text) <= 65535):
        messagebox.showwarning("Error", "Listen Port must be a number between 1 and 65535.")
        return False
    if any(ch in _UNSAFE_RULE_NAME_CHARS for ch in rule_name):
        messagebox.showwarning("Error", 'Rule Name must not contain any of: " & | < > ^ %')
        return False
    return True


def add_rule():
    """Create a portproxy mapping (listen address:port -> 127.0.0.1:port) and an inbound TCP firewall rule.

    Reads the form fields, runs both netsh commands, reports success or the
    output of the failed commands, then refreshes the port list and clears the form.
    """
    listen_address = address_entry.get()
    listen_port = port_entry.get()
    protocol = protocol_entry.get()
    rule_name = rule_name_entry.get()

    if not check_fields(listen_port, rule_name):
        return
    # The protocol field is free text that goes into a shell command.
    if not protocol.isalnum():
        messagebox.showwarning("Error", "Protocol must be a plain word such as tcp.")
        return

    commands = [
        f"netsh interface portproxy add v4tov4 listenaddress={listen_address} listenport={listen_port} protocol={protocol} connectaddress=127.0.0.1 connectport={listen_port}",
        # Quoted so rule names containing spaces reach netsh as a single value.
        f'netsh advfirewall firewall add rule name="{rule_name}" dir=in action=allow protocol=TCP localport={listen_port}'
    ]
    failures = []
    for cmd in commands:
        print(f"Adding rule: {cmd}")
        result = subprocess.run(cmd, shell=True, capture_output=True)
        # netsh prints in the console code page (GBK on Chinese Windows); never crash on odd bytes.
        output = result.stdout.decode('gbk', errors='replace')
        print(output)
        if result.returncode != 0:
            failures.append(output.strip() or f"{cmd} exited with code {result.returncode}")
    if failures:
        messagebox.showerror("Error", "\n".join(failures))
    else:
        messagebox.showinfo("Success", "Portproxy and firewall rule added.")

    # Refresh the port list and clear the form
    port_entry['values'] = get_port_list()
    port_entry.set('')
    rule_name_entry.set('')


def delete_rule():
    """Remove the portproxy mapping for the entered address/port and the named firewall rule.

    Reports success or the output of the failed commands, then refreshes the
    port list and clears the form.
    """
    listen_address = address_entry.get()
    listen_port = port_entry.get()
    protocol = protocol_entry.get()
    rule_name = rule_name_entry.get()

    if not check_fields(listen_port, rule_name):
        return
    # The protocol field is free text that goes into a shell command.
    if not protocol.isalnum():
        messagebox.showwarning("Error", "Protocol must be a plain word such as tcp.")
        return

    commands = [
        f"netsh interface portproxy delete v4tov4 listenaddress={listen_address} listenport={listen_port} protocol={protocol}",
        # Quoted so rule names containing spaces reach netsh as a single value.
        f'netsh advfirewall firewall delete rule name="{rule_name}"'
    ]
    failures = []
    for cmd in commands:
        print(f"Deleting rule: {cmd}")
        result = subprocess.run(cmd, shell=True, capture_output=True)
        # netsh prints in the console code page (GBK on Chinese Windows); never crash on odd bytes.
        output = result.stdout.decode('gbk', errors='replace')
        print(output)
        if result.returncode != 0:
            failures.append(output.strip() or f"{cmd} exited with code {result.returncode}")
    if failures:
        messagebox.showerror("Error", "\n".join(failures))
    else:
        messagebox.showinfo("Success", "Portproxy and firewall rule deleted.")

    # Refresh the port list and clear the form
    port_entry['values'] = get_port_list()
    port_entry.set('')
    rule_name_entry.set('')


def query_proxy_rule():
    """Tell the user whether a portproxy rule exists for the entered port.

    A rule matches when its listen port or its connect port equals the entered
    port exactly (whole column, not substring).
    """
    listen_port = port_entry.get()
    # Only the port is needed for a lookup
    if not listen_port:
        messagebox.showwarning("Error", "Please ensure Listen Port are not empty.")
        return
    cmd = "netsh interface portproxy show all"
    print(f"Querying: {cmd}")
    try:
        raw = subprocess.check_output(cmd, shell=True)
    except subprocess.CalledProcessError as e:
        messagebox.showerror("Error", f"{cmd} failed with exit code {e.returncode}.")
        return
    try:
        result = raw.decode('utf-8')  # try UTF-8 first
    except UnicodeDecodeError:
        result = raw.decode('gbk', errors='replace')  # fall back to GBK
    print(result)

    # Data rows: <listen address> <listen port> <connect address> <connect port>.
    # Comparing whole columns avoids reporting port 80 when only 8080 is mapped.
    wanted = listen_port.strip()
    found = any(
        len(fields) >= 4 and wanted in (fields[1], fields[3])
        for fields in (line.split() for line in result.splitlines())
    )
    if not found:
        messagebox.showinfo("Result", f"No portproxy rule found for the given port {listen_port}.")
    else:
        messagebox.showinfo("Result", f"Portproxy rule found for {listen_port}.")


def parse_portproxy_ports(output):
    """Return listen ports of rules forwarding to 127.0.0.1 in ``netsh interface portproxy show all`` text.

    Data rows look like ``<listen address> <listen port> <connect address> <connect port>``.
    Ports are returned once each, in output order.
    """
    ports = []
    for line in output.splitlines():
        fields = line.split()
        if len(fields) >= 4 and fields[2] == '127.0.0.1' and fields[1].isdigit():
            ports.append(fields[1])
    return list(dict.fromkeys(ports))


def get_port_list():
    """Return the listen ports of the current portproxy rules that forward to 127.0.0.1.

    The listen port is what delete_rule passes as ``listenport=``; the previous
    version returned the connect-port column instead. Returns [] if netsh fails,
    so the window can still start.
    """
    cmd = "netsh interface portproxy show all"
    print(f"Querying: {cmd}")
    try:
        raw = subprocess.check_output(cmd, shell=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Query failed: {e}")
        return []
    try:
        result = raw.decode('utf-8')  # try UTF-8 first
    except UnicodeDecodeError:
        result = raw.decode('gbk', errors='replace')  # fall back to GBK
    return parse_portproxy_ports(result)


def get_file_wall_list(event):
    """Combobox callback: list the TCP firewall rules bound to the selected port.

    Fills the Rule Name box with their names and pre-selects the first one,
    or clears the box when there is none.
    """
    selected_port = event.widget.get()
    command = "netsh advfirewall firewall show rule name=all"
    print(f"Querying: {command}")
    raw_output = subprocess.check_output(command, shell=True)
    try:
        text = raw_output.decode('utf-8')  # try UTF-8 first
    except UnicodeDecodeError:
        text = raw_output.decode('gbk')  # fall back to GBK
    names = [entry['规则名称'] for entry in parse_firewall_rules(text, selected_port)]
    rule_name_entry['values'] = names
    rule_name_entry.set(names[0] if names else '')


def parse_firewall_rules(output: str, port: str) -> list:
    """Extract TCP firewall rules bound to *port* from netsh firewall rule text.

    The field labels matched below ("规则名称", "本地端口", "协议", ...) are the
    Chinese-locale netsh labels, so this presumably only works on Chinese
    Windows. NOTE(review): confirm whether other locales need support.

    :param output: decoded ``netsh advfirewall firewall show rule name=all`` output
        (that is what the caller in this file passes).
    :param port: local port to match; compared as an exact string with 本地端口.
    :return: list of dicts (label -> value) for rules whose 本地端口 equals *port*,
        is not 任何/RPC/RPC-EPMap, whose 协议 is 'TCP', and whose name's first five
        characters are not an entry of ``name_list``.
    """
    rules = []
    rule = {}
    lines = output.split('\n')
    name_list = ['HTTP-In', 'HTTP-Out', 'HTTP-Streaming-In', 'NB-Session-In', 'NP-In', 'PPTP-In', 'RTSP-Streaming-In',
                 'SMB-In', 'SSDP TCP-In', 'SSTP-In', 'TCP', 'TCP-In', 'TCP-WS-In', 'TCP-WSS-In', 'UPnP-In',
                 'WSD Events-In', 'WSD EventsSecure-In', 'iWARP-In', 'qWave-TCP-In', 'DCOM-In', '@FirewallAPI',
                 '家庭组输入']  # rule names to filter out
    # Index of the latest dashed separator; the rule's name is on the line just above it.
    rule_name_index = 0
    for index, line in enumerate(lines):
        if line.startswith('------'):
            # A separator starts a new block: store the fields collected so far.
            rule_name_index = index
            if rule:
                rules.append(rule)
                rule = {}
        else:
            try:
                key, value = line.split(': ', 1)
                # Skip fields whose value is empty
                if not value:
                    continue
                rule[key.strip()] = value.strip()
            except ValueError:
                # Not a "key: value" line (blank line, heading, ...): ignore it.
                pass
            if "规则名称" in rule and "本地端口" in rule:
                if rule["本地端口"] not in ['任何', 'RPC', 'RPC-EPMap']:
                    # NOTE(review): raises KeyError if 本地端口 appears before 协议 in a block.
                    if rule["协议"] == 'TCP':
                        # The name line precedes the separator, so it is parsed into the
                        # dict that is still open at that point (the previous block).
                        # Re-read this block's name from the line above its separator.
                        rule["规则名称"] = lines[rule_name_index - 1].split(': ')[1].strip()
                        if all(name not in rule["规则名称"] for name in name_list):
                            # Debug output of candidate rules
                            print(rule["规则名称"], rule["本地端口"], rule["协议"])

    # Store the last block (no separator follows it)
    if rule:
        rules.append(rule)
    # Keep rules that have a local port equal to *port*, whose local port is not
    # 任何/RPC/RPC-EPMap, and whose protocol is TCP.
    filter_rules = []
    for rule in rules:
        if '本地端口' in rule and rule['本地端口'] == port:
            if rule["本地端口"] not in ['任何', 'RPC', 'RPC-EPMap'] and rule["协议"] == 'TCP':
                # NOTE(review): only the first five characters of the name are tested for
                # exact membership in name_list, unlike the substring test above, so entries
                # longer than five characters can never match here. Confirm the intended filter.
                if rule["规则名称"][:5] not in name_list:
                    filter_rules.append(rule)
    return filter_rules


def on_exit():
    """Close the application by destroying the main Tk window."""
    main_window = root
    main_window.destroy()


def get_local_ip():
    """
    Return this machine's IPv4 address on the interface used for outbound traffic.

    A UDP socket is "connected" to a public address. This sends no packets, but
    makes the OS pick the outgoing interface, whose address is then read back.
    When no route exists (e.g. offline), '127.0.0.1' is returned instead of
    crashing at startup.
    :return: ip
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('8.8.8.8', 80))
            return s.getsockname()[0]
    except OSError:
        return '127.0.0.1'


# Window size in pixels (also used to center the window on the screen)
WINDOW_WIDTH = 500
WINDOW_HEIGHT = 300

root = tk.Tk()
root.title("Portproxy and Firewall Rule Manager")

# Center the window: set size and position in a single geometry call
screen_width = root.winfo_screenwidth()
screen_height = root.winfo_screenheight()
x = (screen_width - WINDOW_WIDTH) // 2
y = (screen_height - WINDOW_HEIGHT) // 2
root.geometry(f"{WINDOW_WIDTH}x{WINDOW_HEIGHT}+{x}+{y}")

# Shared font and colours
font = ("Helvetica", 12)
bg_color = "#f0f0f0"
fg_color = "#000000"

# Listen address: pre-filled with this machine's IP and read-only
address_label = tk.Label(root, text="Listen Address:", font=font, bg=bg_color, fg=fg_color)
address_label.grid(row=0, column=1, padx=10, pady=10)
address_entry = tk.Entry(root, font=font)
address_entry.grid(row=0, column=2, padx=10, pady=10)
address_entry.insert(0, get_local_ip())
address_entry.config(state='disabled')

# Listen port: offers the ports of existing portproxy rules (typing a new one is allowed);
# selecting one loads the matching firewall rule names
port_label = tk.Label(root, text="Listen Port:", font=font, bg=bg_color, fg=fg_color)
port_label.grid(row=1, column=1, padx=10, pady=10)
port_entry = ttk.Combobox(root, font=font, width=18)
port_entry.grid(row=1, column=2, padx=10, pady=10)
port_entry['values'] = get_port_list()
port_entry.set('')
port_entry.bind('<<ComboboxSelected>>', get_file_wall_list)

# Protocol passed to netsh portproxy (default tcp)
protocol_label = tk.Label(root, text="Protocol:", font=font, bg=bg_color, fg=fg_color)
protocol_label.grid(row=2, column=1, padx=10, pady=10)
protocol_entry = tk.Entry(root, font=font)
protocol_entry.grid(row=2, column=2, padx=10, pady=10)
protocol_entry.insert(0, "tcp")

# Firewall rule name: filled from the selected port, or typed for a new rule
rule_name_label = tk.Label(root, text="Rule Name:", font=font, bg=bg_color, fg=fg_color)
rule_name_label.grid(row=3, column=1, padx=10, pady=10)
rule_name_entry = ttk.Combobox(root, font=font, width=18)
rule_name_entry.grid(row=3, column=2, padx=10, pady=10)

# Add-rule button
add_rule_button = tk.Button(root, text="Add Rule", command=add_rule, font=font, bg=bg_color, fg=fg_color)
add_rule_button.grid(row=4, column=1, padx=10, pady=10)

# Delete-rule button
delete_rule_button = tk.Button(root, text="Delete Rule", command=delete_rule, font=font, bg=bg_color, fg=fg_color)
delete_rule_button.grid(row=4, column=2, padx=10, pady=10)

# Query-rule button
query_rule_button = tk.Button(root, text="Query Rule", command=query_proxy_rule, font=font, bg=bg_color, fg=fg_color)
query_rule_button.grid(row=4, column=3, padx=10, pady=10)

root.mainloop()

# 打包
# pyinstaller --onefile --windowed window_port.py

    等一分钟 Wait One Minute
    徐誉滕