Gradio如何获取客户端信息
原始的 Gradio 应用(server_app.py)
我编写了一个简单Gradio应用,代码如下(server_app.py):
import gradio as gr
def gradio_submit_fn(arg1, arg2):
    """Combine the two submitted values into a single display string.

    :param arg1: first submitted value (wired to the number input).
    :param arg2: second submitted value (wired to the text input).
    :return: text of the form ``"arg1: <arg1>, arg2:<arg2>"``.
    """
    template = "arg1: {}, arg2:{}"
    return template.format(arg1, arg2)
# Build the demo UI. Labels: "数字" = number, "文字" = text, "提交" = submit.
with gr.Blocks() as gradio_app:
    samples = gr.Number(label="数字", value=1)
    article = gr.Textbox(label="文字")
    btn = gr.Button("提交")
    output = gr.HTML(label="label", value="value")
    # On click, pass the number and the text to gradio_submit_fn and show the
    # returned string in the HTML component.
    btn.click(fn=gradio_submit_fn, inputs=[samples, article], outputs=output)

if __name__ == "__main__":
    gradio_port = 9091
    # Listen on all interfaces and do not open a browser tab on start.
    gradio_app.launch(server_name='0.0.0.0', server_port=gradio_port, inbrowser=False)
当我在浏览器中访问该应用时,服务端控制台没有任何输出。我想实现如下功能:每当浏览器访问 Gradio 应用时,控制台都打印客户端的 IP 地址等信息。
遗憾的是,Gradio 并没有提供"每次访问都记录客户端信息"的现成功能。事件处理函数虽然可以通过 gr.Request 参数拿到本次请求的客户端信息,但只在事件触发时可用,覆盖不到页面本身和静态资源请求。于是我想到在前面加一层 Flask 代理:让 Flask 接收并转发所有请求,由 Flask 来打印日志,效果是一样的。
基本使用
准备好flask的代理(server_app_proxy.py)
import logging
import requests
import threading
from flask import Flask, request, Response, stream_with_context
# Flask app that hosts the catch-all proxy route.
app = Flask(__name__)
# Both ports are filled in by create_flask_proxy() before the server thread starts.
proxy_port = None  # port the Flask proxy listens on
gradio_port = None  # local port of the Gradio server that requests are forwarded to
# Custom filter for Flask's (werkzeug) request log.
class StaticFilter(logging.Filter):
    """Drop werkzeug log lines for static ``.css`` / ``.js`` requests."""

    def filter(self, record):
        """Return False (suppress the record) when it logs a .css/.js request."""
        message = record.getMessage()
        # Substrings that identify a request line for a static asset.
        skipped_markers = ('.css HTTP/', '.js HTTP/')
        return not any(marker in message for marker in skipped_markers)
# Headers that apply to a single connection only (RFC 7230 section 6.1); a
# proxy must not relay them to the next hop.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'trailers', 'transfer-encoding', 'upgrade',
})
# requests transparently decompresses the upstream body and Flask recomputes
# the length, so these upstream headers would no longer match what we send.
_STALE_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}


# Flask proxy route: capture every path and every method and forward it to Gradio.
@app.route('/', defaults={'path': ''}, methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
def proxy(path):
    """Forward the current request to the Gradio server and relay its reply.

    Prints the client address, the User-Agent and any request body to stdout.
    ``text/event-stream`` replies are streamed back as they arrive; all other
    replies are buffered and returned in one piece.

    :param path: request path captured by the route, without the leading slash.
    :return: a Flask ``Response`` carrying Gradio's status, headers and body,
        or a 502 response when Gradio cannot be reached.
    """
    req_headers = {key: value for key, value in request.headers
                   if key.lower() not in _HOP_BY_HOP_HEADERS}
    print(f"Client IP: {request.remote_addr}")
    user_agent = req_headers.get('User-Agent', 'No User-Agent')
    print(f"User-Agent: {user_agent}")
    query_string = request.query_string.decode("utf-8")
    base_url = f'http://127.0.0.1:{gradio_port}/{path}'  # noqa
    full_url = f'{base_url}?{query_string}' if query_string else base_url
    req_data = request.get_data()
    if req_data:
        print("req_data:", req_data)
    try:
        # The Cookie header is already part of req_headers, so no cookies= needed.
        resp = requests.request(
            method=request.method,
            url=full_url,
            headers=req_headers,
            data=req_data,
            allow_redirects=False,
            stream=True,
        )
    except requests.exceptions.ConnectionError as exc:
        # Gradio not started yet (or gone): answer like a gateway would instead
        # of letting Flask turn the exception into a bare 500.
        print(f"Gradio unreachable at {full_url}: {exc}")
        return Response('Bad Gateway', status=502)

    # A list (not a dict) so repeated headers such as Set-Cookie are all kept.
    resp_headers = [(name, value) for name, value in resp.raw.headers.items()
                    if name.lower() not in _STALE_RESPONSE_HEADERS]

    # Replies without a Content-Type (e.g. 204/304) must not crash the proxy.
    if 'text/event-stream' in resp.headers.get('Content-Type', '').lower():
        def generate():
            try:
                # chunk_size=None yields data as soon as it arrives instead of
                # running one Python iteration per byte.
                yield from resp.iter_content(chunk_size=None)
            finally:
                resp.close()
        return Response(stream_with_context(generate()), status=resp.status_code, headers=resp_headers)

    response = Response(resp.content, status=resp.status_code, headers=resp_headers)
    return response
# Run Flask's built-in server (blocks the calling thread).
def flask_run():
    """Serve the proxy ``app`` on all interfaces at ``proxy_port`` with debug off."""
    server_settings = {'host': '0.0.0.0', 'port': proxy_port, 'debug': False}
    app.run(**server_settings)
# Start the proxy in the background.
def create_flask_proxy(proxy_port_: int, gradio_port_: int):
    """Start the Flask proxy on a background thread and return immediately.

    :param proxy_port_: port the proxy listens on (all interfaces).
    :param gradio_port_: local port of the Gradio server; requests are forwarded there.
    :return: None
    """
    global proxy_port, gradio_port
    proxy_port, gradio_port = proxy_port_, gradio_port_
    # Hide werkzeug log lines for static .css/.js requests.
    logging.getLogger('werkzeug').addFilter(StaticFilter())
    # daemon=True: the proxy thread must not keep the process alive once the
    # main thread (Gradio's launch()) has finished or been interrupted (Ctrl+C).
    flask_thread = threading.Thread(target=flask_run, name='flask-proxy', daemon=True)
    flask_thread.start()
对应的Gradio应用(server_app.py)
import gradio as gr
from server_app_proxy import create_flask_proxy
def gradio_submit_fn(arg1, arg2):
    """Combine the two submitted values into a single display string.

    :param arg1: first submitted value (wired to the number input).
    :param arg2: second submitted value (wired to the text input).
    :return: text of the form ``"arg1: <arg1>, arg2:<arg2>"``.
    """
    template = "arg1: {}, arg2:{}"
    return template.format(arg1, arg2)
# Build the demo UI. Labels: "数字" = number, "文字" = text, "提交" = submit.
with gr.Blocks() as gradio_app:
    samples = gr.Number(label="数字", value=1)
    article = gr.Textbox(label="文字")
    btn = gr.Button("提交")
    output = gr.HTML(label="label", value="value")
    # On click, pass the number and the text to gradio_submit_fn and show the
    # returned string in the HTML component.
    btn.click(fn=gradio_submit_fn, inputs=[samples, article], outputs=output)

if __name__ == "__main__":
    # Browsers connect to proxy_port; the proxy forwards to Gradio on gradio_port.
    proxy_port = 9090
    gradio_port = 9091
    # The proxy is started in the background (this call returns right away),
    # then Gradio itself is launched on gradio_port.
    create_flask_proxy(proxy_port, gradio_port)
    gradio_app.launch(server_name='0.0.0.0', server_port=gradio_port, inbrowser=False)
现在通过 Flask 监听的端口 9090(而不是 Gradio 的端口 9091)访问应用,控制台日志如下:
高级使用
如果代理需要长时间运行,推荐用 gunicorn 代替 Flask 自带的开发服务器(注意 gunicorn 不支持 Windows)。只需替换 server_app_proxy.py 文件即可:
pip install gunicorn
server_app_proxy.py
import requests
from flask import Flask, request, Response, stream_with_context
from gunicorn.app.base import BaseApplication
from multiprocessing import Process
# Headers that apply to a single connection only (RFC 7230 section 6.1); a
# proxy must not relay them to the next hop.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'trailers', 'transfer-encoding', 'upgrade',
})
# requests transparently decompresses the upstream body and Flask recomputes
# the length, so these upstream headers would no longer match what we send.
_STALE_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}


# Define the gunicorn service.
class StandaloneApplication(BaseApplication):
    """Run gunicorn in-process, serving a Flask app that proxies to Gradio.

    ``options`` mixes gunicorn settings (``bind``, ``workers``, ``accesslog``,
    ...) with proxy-specific keys. ``load_config`` applies only the keys that
    gunicorn recognises, and the proxy route reads ``options["gradio_port"]``.
    """

    def __init__(self, options=None):
        self.options = options or {}
        # Assigned first: load_config() reads self.options, and the base
        # initializer is expected to call it (gunicorn custom-app pattern).
        super().__init__()

    def load_config(self):
        """Copy the recognised, non-None gunicorn settings from options into self.cfg."""
        config = {key: value for key, value in self.options.items()
                  if key in self.cfg.settings and value is not None}
        for key, value in config.items():
            self.cfg.set(key.lower(), value)

    def load(self):
        """Build and return the Flask WSGI app that forwards every request to Gradio."""
        app = Flask(__name__)

        # Flask proxy route: capture every path and every method and forward it to Gradio.
        @app.route('/', defaults={'path': ''}, methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
        @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'])
        def proxy(path):
            """Forward the current request to Gradio and relay its reply.

            ``text/event-stream`` replies are streamed back as they arrive;
            all other replies are buffered and returned in one piece.

            :param path: request path captured by the route, without the leading slash.
            :return: a Flask ``Response`` carrying Gradio's status, headers and
                body, or a 502 response when Gradio cannot be reached.
            """
            req_headers = {key: value for key, value in request.headers
                           if key.lower() not in _HOP_BY_HOP_HEADERS}
            query_string = request.query_string.decode("utf-8")
            base_url = f'http://127.0.0.1:{self.options["gradio_port"]}/{path}'  # noqa
            full_url = f'{base_url}?{query_string}' if query_string else base_url
            req_data = request.get_data()
            if req_data:
                print("req_data:", req_data)
            try:
                # The Cookie header is already part of req_headers, so no cookies= needed.
                resp = requests.request(
                    method=request.method,
                    url=full_url,
                    headers=req_headers,
                    data=req_data,
                    allow_redirects=False,
                    stream=True,
                )
            except requests.exceptions.ConnectionError as exc:
                # Gradio not started yet (or gone): answer like a gateway would
                # instead of letting the exception become a bare 500.
                print(f"Gradio unreachable at {full_url}: {exc}")
                return Response('Bad Gateway', status=502)

            # A list (not a dict) so repeated headers such as Set-Cookie are all kept.
            resp_headers = [(name, value) for name, value in resp.raw.headers.items()
                            if name.lower() not in _STALE_RESPONSE_HEADERS]

            # Replies without a Content-Type (e.g. 204/304) must not crash the proxy.
            if 'text/event-stream' in resp.headers.get('Content-Type', '').lower():
                def generate():
                    try:
                        # chunk_size=None yields data as soon as it arrives
                        # instead of running one Python iteration per byte.
                        yield from resp.iter_content(chunk_size=None)
                    finally:
                        resp.close()
                return Response(stream_with_context(generate()), status=resp.status_code, headers=resp_headers)

            response = Response(resp.content, status=resp.status_code, headers=resp_headers)
            return response

        return app
def run_server(proxy_port, gradio_port):
    """Run the gunicorn-hosted proxy in the foreground until it is shut down.

    :param proxy_port: port the proxy listens on (all interfaces).
    :param gradio_port: local port of the Gradio server that requests are forwarded to.
    """
    options = {
        'bind': f'0.0.0.0:{proxy_port}',
        'workers': 4,
        # Gradio keeps server-sent-event connections open while a job runs.
        # Plain sync workers serve one request at a time and are killed when a
        # request outlives `timeout`, so give each worker a thread pool
        # (threads > 1 makes gunicorn use the gthread worker class).
        'threads': 8,
        'proxy_port': proxy_port,
        'gradio_port': gradio_port,
        'accesslog': '-',  # '-' means log to stdout
        'errorlog': '-',  # '-' means log to stderr
    }
    StandaloneApplication(options).run()
# Start the proxy in the background.
def create_flask_proxy(proxy_port: int, gradio_port: int):
    """Launch the gunicorn proxy in a separate process and return immediately.

    :param proxy_port: port the proxy listens on (all interfaces).
    :param gradio_port: local port of the Gradio server that requests are forwarded to.
    :return: None
    """
    server_process = Process(target=run_server, args=(proxy_port, gradio_port))
    server_process.start()