已复制
全屏展示
复制代码

Gradio如何获取客户端信息


· 4 min read

原始的 server_app.py 应用

我编写了一个简单的 Gradio 应用,代码如下(server_app.py):

import gradio as gr


def gradio_submit_fn(arg1, arg2):
    """Return a one-line summary of both inputs for display in the HTML output."""
    message = f"arg1: {arg1}, arg2:{arg2}"
    return message


# Demo UI: a number and a text box whose values are echoed into an HTML
# component when the button is clicked. Components appear in creation order.
with gr.Blocks() as gradio_app:
    samples = gr.Number(label="数字", value=1)  # "Number" input, defaults to 1
    article = gr.Textbox(label="文字")  # "Text" input
    btn = gr.Button("提交")  # "Submit" button
    output = gr.HTML(label="label", value="value")  # result area (placeholder label/value)
    # On click, call gradio_submit_fn(<samples value>, <article value>) and
    # render its return value into `output`.
    btn.click(fn=gradio_submit_fn, inputs=[samples, article], outputs=output)

if __name__ == "__main__":
    gradio_port = 9091
    # Listen on all interfaces (0.0.0.0) at port 9091; do not auto-open a browser.
    gradio_app.launch(server_name='0.0.0.0', server_port=gradio_port, inbrowser=False)

当我用浏览器访问这个应用时,服务端控制台没有任何输出。我想实现的功能是:每当浏览器访问 Gradio 应用时,在控制台打印客户端的 IP 地址、User-Agent 等信息。

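很遗憾,Gradio 官方并没有提供"记录每一次 HTTP 访问"的功能。事件处理函数可以声明一个 `gr.Request` 类型的参数,通过 `request.client.host`、`request.headers` 读取客户端信息,但只有在事件被触发时才能拿到。最后我想到用 Flask 在中间加一层反向代理:由 Flask 转发所有请求并打印日志,效果是一样的。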

基本使用

首先准备 Flask 代理(server_app_proxy.py):

import logging
import requests
import threading
from flask import Flask, request, Response, stream_with_context

app = Flask(__name__)
# Both ports are assigned by create_flask_proxy() before the server thread starts.
proxy_port = gradio_port = None


# 自定义 Flask 打印日志
class StaticFilter(logging.Filter):
    def filter(self, record):
        line_messages = record.getMessage()
        # 不打印如下日志信息
        filter_list = ['.css HTTP/', '.js HTTP/']
        for f in filter_list:
            if f in line_messages:
                return False
        return True


# Catch-all Flask proxy route: forwards every path and method to Gradio.

# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be forwarded by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'transfer-encoding', 'upgrade',
})
# `requests` transparently decompresses gzip/deflate bodies and Werkzeug sets
# its own Content-Length, so the upstream values of these headers no longer
# describe the body that is sent back to the browser.
_EXCLUDED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}


@app.route('/', defaults={'path': ''}, methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
def proxy(path):
    """Forward the current request to the local Gradio server and relay the reply.

    Prints the client's User-Agent and any non-empty request body to stdout.
    ``text/event-stream`` responses are streamed back as data arrives; all
    other responses are buffered and returned in one piece.

    :param path: Path captured by the route (``''`` for the site root).
    :return: A Flask ``Response`` carrying the upstream status and headers, or
        a plain-text 502 when the Gradio server cannot be reached.
    """
    req_headers = {key: value for key, value in request.headers
                   if key.lower() not in _HOP_BY_HOP_HEADERS}
    user_agent = req_headers.get('User-Agent', 'No User-Agent')
    print(f"User-Agent: {user_agent}")
    # Pass the real client address upstream; otherwise Gradio only ever sees
    # the proxy's 127.0.0.1. NOTE(review): whether gr.Request reports this
    # value depends on the upstream server's proxy-header settings - verify.
    client_ip = request.remote_addr
    if client_ip:
        prior = req_headers.get('X-Forwarded-For')
        req_headers['X-Forwarded-For'] = f'{prior}, {client_ip}' if prior else client_ip
    query_string = request.query_string.decode("utf-8")
    base_url = f'http://127.0.0.1:{gradio_port}/{path}'  # noqa
    full_url = f'{base_url}?{query_string}' if query_string else base_url
    req_data = request.get_data()
    if req_data:
        print("req_data:", req_data)
    try:
        resp = requests.request(
            method=request.method,
            url=full_url,
            headers=req_headers,
            data=req_data,
            cookies=request.cookies,
            allow_redirects=False,
            stream=True,
        )
    except requests.RequestException as exc:
        # Gradio unreachable (e.g. still starting up): answer like a gateway.
        print("Upstream request failed:", exc)
        return Response("Bad Gateway", status=502, mimetype='text/plain')
    # A list of pairs (not a dict) keeps repeated headers such as multiple
    # Set-Cookie lines; urllib3's header dict yields one pair per header line.
    resp_headers = [(name, value) for name, value in resp.raw.headers.items()
                    if name.lower() not in _EXCLUDED_RESPONSE_HEADERS]
    # Content-Type may be absent (e.g. 304 Not Modified), hence the '' default.
    if 'text/event-stream' in resp.headers.get('Content-Type', '').lower():
        def generate():
            try:
                # chunk_size=None yields data as it arrives, not byte-by-byte.
                yield from resp.iter_content(chunk_size=None)
            finally:
                # Release the upstream connection even if the client disconnects.
                resp.close()

        return Response(stream_with_context(generate()), status=resp.status_code, headers=resp_headers)
    return Response(resp.content, status=resp.status_code, headers=resp_headers)


# Run the Flask proxy server.
def flask_run():
    """Serve ``app`` on every interface at ``proxy_port`` (blocks until shutdown)."""
    app.run(debug=False, port=proxy_port, host='0.0.0.0')


# Start the proxy in the background.
def create_flask_proxy(proxy_port_: int, gradio_port_: int, daemon: bool = True):
    """Start the Flask reverse proxy on a background thread and return at once.

    Stores both ports in this module's globals (read by ``flask_run`` and the
    proxy route), hides static-asset lines from the werkzeug log, then runs
    ``flask_run`` on a new thread.

    :param proxy_port_: Port the proxy listens on (all interfaces).
    :param gradio_port_: Port of the Gradio server on 127.0.0.1.
    :param daemon: Run the thread as a daemon (default) so the interpreter can
        exit when the caller's blocking work, e.g. ``gradio_app.launch()``,
        returns (a non-daemon thread would keep the process hanging after
        Ctrl+C). Pass ``False`` to keep the process alive for the proxy alone.
    :return: The started ``threading.Thread``, so callers can ``join()`` it.
    """
    global proxy_port, gradio_port
    proxy_port, gradio_port = proxy_port_, gradio_port_
    log = logging.getLogger('werkzeug')
    # Avoid stacking duplicate filters if this is called more than once.
    if not any(isinstance(existing, StaticFilter) for existing in log.filters):
        log.addFilter(StaticFilter())
    flask_thread = threading.Thread(target=flask_run, name='flask-proxy', daemon=daemon)
    flask_thread.start()
    return flask_thread

对应的Gradio应用(server_app.py)

import gradio as gr

from server_app_proxy import create_flask_proxy


def gradio_submit_fn(arg1, arg2):
    """Combine both inputs into one summary string for the HTML output."""
    summary = f"arg1: {arg1}, arg2:{arg2}"
    return summary


# Demo UI: a number and a text box whose values are echoed into an HTML
# component when the button is clicked. Components appear in creation order.
with gr.Blocks() as gradio_app:
    samples = gr.Number(label="数字", value=1)  # "Number" input, defaults to 1
    article = gr.Textbox(label="文字")  # "Text" input
    btn = gr.Button("提交")  # "Submit" button
    output = gr.HTML(label="label", value="value")  # result area (placeholder label/value)
    # On click, call gradio_submit_fn(<samples value>, <article value>) and
    # render its return value into `output`.
    btn.click(fn=gradio_submit_fn, inputs=[samples, article], outputs=output)


if __name__ == "__main__":
    proxy_port = 9090
    gradio_port = 9091
    # Start the logging proxy first; it listens publicly on proxy_port.
    create_flask_proxy(proxy_port, gradio_port)
    # Bind Gradio to loopback only: the proxy reaches it via 127.0.0.1, and
    # clients can no longer bypass the proxy (and its logging) via gradio_port.
    gradio_app.launch(server_name='127.0.0.1', server_port=gradio_port, inbrowser=False)

现在改为访问 Flask 监听的 9090 端口(而不是 Gradio 的 9091 端口),然后查看控制台日志:

高级使用

Flask 自带的开发服务器不适合在生产环境中长期运行。如果代理需要长期运行,推荐改用 gunicorn(仅支持 Linux/macOS 等类 Unix 系统),只需替换 server_app_proxy.py 文件即可:

pip install gunicorn

server_app_proxy.py

import requests
from flask import Flask, request, Response, stream_with_context
from gunicorn.app.base import BaseApplication
from multiprocessing import Process


# Gunicorn application that embeds the Flask reverse proxy.
class StandaloneApplication(BaseApplication):
    """Serve the Gradio reverse proxy with gunicorn, configured from a dict.

    ``options`` mixes gunicorn settings (``bind``, ``workers``, ``accesslog``,
    ...) with the proxy-specific key ``gradio_port``. Only keys present in
    ``self.cfg.settings`` with a non-``None`` value are copied into the
    gunicorn config; the full dict stays available to :meth:`load`.
    """

    # Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection
    # and must not be forwarded by a proxy in either direction.
    _HOP_BY_HOP_HEADERS = frozenset({
        'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
        'te', 'trailer', 'transfer-encoding', 'upgrade',
    })
    # `requests` transparently decompresses gzip/deflate bodies and Werkzeug
    # sets its own Content-Length, so the upstream values of these headers no
    # longer describe the body that is sent back to the browser.
    _EXCLUDED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}

    def __init__(self, options=None):
        # Store options before the base initialiser: gunicorn's BaseApplication
        # loads the configuration (calling load_config) while initialising.
        self.options = options or {}
        super().__init__()

    def load_config(self):
        """Copy recognised, non-``None`` gunicorn settings into ``self.cfg``."""
        config = {key: value for key, value in self.options.items()
                  if key in self.cfg.settings and value is not None}
        for key, value in config.items():
            self.cfg.set(key.lower(), value)

    def load(self):
        """Build the Flask app that forwards every request to Gradio on 127.0.0.1."""
        app = Flask(__name__)
        hop_by_hop = self._HOP_BY_HOP_HEADERS
        excluded_response = self._EXCLUDED_RESPONSE_HEADERS

        # Catch-all proxy route: every path and method is forwarded to Gradio.
        @app.route('/', defaults={'path': ''}, methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
        @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
        def proxy(path):
            """Relay the request to Gradio; stream SSE replies, buffer the rest."""
            req_headers = {key: value for key, value in request.headers
                           if key.lower() not in hop_by_hop}
            # Pass the real client address upstream. NOTE(review): whether
            # gr.Request reports it depends on the upstream's proxy-header
            # settings - verify.
            client_ip = request.remote_addr
            if client_ip:
                prior = req_headers.get('X-Forwarded-For')
                req_headers['X-Forwarded-For'] = f'{prior}, {client_ip}' if prior else client_ip
            query_string = request.query_string.decode("utf-8")
            base_url = f'http://127.0.0.1:{self.options["gradio_port"]}/{path}'  # noqa
            full_url = f'{base_url}?{query_string}' if query_string else base_url
            req_data = request.get_data()
            if req_data:
                print("req_data:", req_data)
            try:
                resp = requests.request(
                    method=request.method,
                    url=full_url,
                    headers=req_headers,
                    data=req_data,
                    cookies=request.cookies,
                    allow_redirects=False,
                    stream=True,
                )
            except requests.RequestException as exc:
                # Gradio unreachable (e.g. still starting up): answer like a gateway.
                print("Upstream request failed:", exc)
                return Response("Bad Gateway", status=502, mimetype='text/plain')
            # A list of pairs (not a dict) keeps repeated headers such as
            # multiple Set-Cookie lines.
            resp_headers = [(name, value) for name, value in resp.raw.headers.items()
                            if name.lower() not in excluded_response]
            # Content-Type may be absent (e.g. 304 Not Modified).
            if 'text/event-stream' in resp.headers.get('Content-Type', '').lower():
                def generate():
                    try:
                        # chunk_size=None yields data as it arrives, not byte-by-byte.
                        yield from resp.iter_content(chunk_size=None)
                    finally:
                        resp.close()

                return Response(stream_with_context(generate()), status=resp.status_code,
                                headers=resp_headers)
            return Response(resp.content, status=resp.status_code, headers=resp_headers)

        return app


def run_server(proxy_port, gradio_port, workers=4, threads=8):
    """Run the gunicorn-hosted proxy in the foreground until it is shut down.

    :param proxy_port: Port to bind on all interfaces.
    :param gradio_port: Port of the Gradio server on 127.0.0.1.
    :param workers: Number of gunicorn worker processes.
    :param threads: Threads per worker. Any value above 1 makes gunicorn use
        its threaded ``gthread`` worker, so each long-lived server-sent-event
        stream ties up one thread instead of a whole sync worker, and the
        worker ``timeout`` no longer cuts long requests off.
    """
    options = {
        'bind': f'0.0.0.0:{proxy_port}',
        'workers': workers,
        'threads': threads,
        'proxy_port': proxy_port,
        'gradio_port': gradio_port,
        'accesslog': '-',  # '-' sends the access log to stdout
        'errorlog': '-',  # '-' sends the error log to stderr
    }
    StandaloneApplication(options).run()


# Start the proxy in the background.
def create_flask_proxy(proxy_port: int, gradio_port: int):
    """Launch ``run_server`` in a child process and return without waiting."""
    server_process = Process(target=run_server, args=(proxy_port, gradio_port))
    server_process.start()
🔗

文章推荐