Hi,大家好,我是编程小6,很荣幸遇见你。我把这些年在开发过程中遇到的问题和想法写了出来,今天说一说基于locust的websocket压测,希望能够帮助到你!
背景:
locust默认只内置了HttpLocust,它是基于requests的Session封装的。如果想测试其他协议(比如websocket、grpc)怎么办?只要自己重写一个客户端实例赋给client即可:
1、重写WebSocketClient类(用来替换self.client原本的HTTP客户端实例)
class WebSocketClient(object):
    """Minimal websocket client used in place of Locust's HTTP ``self.client``.

    Wraps a ``websocket.WebSocket`` and reports the connect time to Locust's
    statistics through ``events.request_success`` / ``events.request_failure``.
    """

    def __init__(self, host):
        self.host = host  # stored for parity with HttpLocust; connect() takes its own URL
        self.ws = websocket.WebSocket()
        # Initialised here so connect() can always return it, even when the very
        # first connection attempt fails (previously an AttributeError).
        self.conn = None

    def connect(self, burl):
        """Open the websocket at *burl* and record the connect time in Locust stats.

        Any websocket-level or socket-level error (timeout, refused connection,
        failed handshake, ...) is reported as a failed request rather than only
        timeouts, so every failed connect shows up in the Locust report.

        Returns ``self.conn`` (the return value of ``WebSocket.connect``;
        NOTE(review): websocket-client appears to return None here — confirm).
        """
        start_time = time.time()
        try:
            self.conn = self.ws.connect(url=burl)
        except (websocket.WebSocketException, OSError) as e:
            # WebSocketTimeoutException is a WebSocketException; OSError covers
            # refused/unreachable connections at the socket level.
            total_time = int((time.time() - start_time) * 1000)
            events.request_failure.fire(request_type="websockt", name='urlweb',
                                        response_time=total_time, exception=e)
        else:
            total_time = int((time.time() - start_time) * 1000)
            events.request_success.fire(request_type="websockt", name='urlweb',
                                        response_time=total_time, response_length=0)
        return self.conn

    def recv(self):
        """Receive one message from the websocket."""
        return self.ws.recv()

    def send(self, msg):
        """Send *msg* over the websocket."""
        self.ws.send(msg)
注意:该类封装了websocket的常用操作:连接、接收、发送。最关键的是events.request_failure.fire和events.request_success.fire这两个调用,它们负责向locust上报性能数据;如果不写,报告中就收集不到任何性能数据。
2、仿照HttpLocust写一个Locust子类,这里叫做WebsocketLocust类
class WebsocketLocust(Locust):
    """Locust user class whose ``client`` speaks websocket instead of HTTP."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the default client with the websocket client built from self.host.
        self.client = WebSocketClient(self.host)
注意:WebsocketLocust继承自Locust;这里主要是把self.client替换为第一步写好的WebSocketClient实例。
3、编写TaskSet类
class SupperDianCan(TaskSet):
    """Task set that connects to the ordering websocket and keeps the session alive."""

    @task
    def test_baidu(self):
        """Connect, then loop forever reading frames.

        A frame whose JSON ``type`` is ``keepalive`` is echoed back unchanged;
        any other frame is answered with ``self.data`` serialised as JSON.
        """
        import json  # local import: this snippet shows no module-level imports

        self.url = 'wss://xxxxxx.xxxx.com/cart/chat?sid=11303&table_no=103&user_id=ofZjWs40HxEzvV08l6m4PnqGbxqc_2_1_&version=2'
        self.data = {}
        self.client.connect(self.url)
        while True:
            recv = self.client.recv()
            print(recv)
            # Parse server frames as JSON. The original used eval(), which would
            # execute arbitrary server-supplied text as Python code.
            message = json.loads(recv)
            if isinstance(message, dict) and message.get('type') == 'keepalive':
                self.client.send(recv)
            else:
                # send() is meant for text/bytes payloads, so serialise the dict.
                self.client.send(json.dumps(self.data))
注意:此类就是任务类,写法跟http的一样,只是这里用到的self.client.xxxx已经变成了我们自己重写的websocket客户端,替换掉了原来基于requests的http客户端。
4、编写站点类(用户类)
class WebsiteUser(WebsocketLocust):
    """Simulated user that runs ``SupperDianCan`` with a 5000-9000 pause between tasks."""

    # Pause bounds between tasks (Locust 0.x min_wait/max_wait, milliseconds).
    max_wait = 9000
    min_wait = 5000
    task_set = SupperDianCan
注意:站点类继承自第二步中的WebsocketLocust。
完整代码1:
from locust import Locust, events, task, TaskSet
import websocket
import time
import gzip
class WebSocketClient(object):
    """Placeholder client that only records the target host."""

    def __init__(self, host):
        # A port could be stored here as well if the target ever needs one.
        self.host = host
class WebSocketLocust(Locust):
    """Locust base class that attaches a ``WebSocketClient`` as ``self.client``."""

    def __init__(self, *args, **kwargs):
        # Run Locust's own initialisation first; the original skipped it, leaving
        # the base class's setup undone.
        super(WebSocketLocust, self).__init__(*args, **kwargs)
        self.client = WebSocketClient("1xx.xx.xx.85")
class UserBehavior(TaskSet):
    """Subscribe to a gzip-compressed market-data feed and keep the connection alive."""

    @task(1)
    def buy(self):
        """Connect, subscribe once, then answer every server ping with a pong.

        Frames are gzip-compressed text. When a frame is a ping, the matching
        pong is sent and the subscription is re-sent (as the original did).
        Any error is printed, and the socket is always closed afterwards.
        """
        import json  # local import: json is not imported at the top of this script

        send_info = '{"sub": "market.ethusdt.kline.1min","id": "id10"}'
        ws = websocket.WebSocket()
        try:
            ws.connect("ws://xxxx.com/r1/xx/ws")
            ws.send(send_info)
            # The original wrapped this in a second `while True` whose body never
            # finished, so the subscription was effectively sent only once.
            while True:
                result = gzip.decompress(ws.recv()).decode('utf-8')
                if result.startswith('{"ping"'):
                    # Read the timestamp from the parsed JSON instead of the fixed
                    # slice result[8:21], which broke on whitespace or other lengths.
                    ping_ts = json.loads(result)['ping']
                    ws.send(json.dumps({"pong": ping_ts}))
                    ws.send(send_info)
        except Exception as e:
            print("error is:", e)
        finally:
            # Release the socket; the original left it open.
            ws.close()
class ApiUser(WebSocketLocust):
    """Simulated user that runs ``UserBehavior`` with a 100-200 pause between tasks."""

    min_wait, max_wait = 100, 200
    task_set = UserBehavior
完整代码2:
# -*- encoding:utf-8 -*-
import gzip
import json
import random
import threading
import time
import zlib
from threading import Timer
import websocket
from gevent._semaphore import Semaphore
from locust import TaskSet, task, Locust, events
# Rendezvous point: the semaphore is created already acquired, and every user
# calls all_locusts_spawned.wait() in on_start; the gate opens when Locust
# fires hatch_complete.
all_locusts_spawned = Semaphore()
all_locusts_spawned.acquire()


def on_hatch_complete(**_kwargs):
    """Release the rendezvous semaphore once Locust reports hatching is complete."""
    all_locusts_spawned.release()


events.hatch_complete += on_hatch_complete
# Module-level state shared by the websocket callbacks.
t2 = repCount = sendCount = pingCount = 0   # t2 holds the previous ping id
stSend = openTime = reqLen = recordSt = 0   # stSend / recordSt are time.time() stamps
repList = []                                # rep response times in milliseconds
printCount = reqSentCount = 1
symbols = ["etcusdt"]                       # pairs the sub/req messages pick from
subbedCount = retSubTopicCount = testFlag = 0
def on_message(ws, message):
    """WebSocketApp message callback for the market-data load test.

    Decompresses and prints every frame, counts ``subbed`` / ``ch`` / ``rep``
    replies, and answers ``ping`` with ``pong``. It also sends a burst of
    ``req`` queries each time another randomised 2-3 s has passed since the
    connection opened. Once a minute it appends a statistics line to
    ``GipRecord.txt``.
    """
    global t2
    global repCount
    global sendCount
    global pingCount
    global stSend
    global printCount
    global subbedCount
    global retSubTopicCount
    global reqSentCount
    # Request payloads; the symbol for each is chosen at random per message.
    req_list = {
        "req_str1": '{"req": "market.%s.kline.1min"}' % random.choice(symbols),
        "req_str2": '{"req": "market.%s.depth.step0"}' % random.choice(symbols),
        "req_str3": '{"req": "market.%s.trade.detail"}' % random.choice(symbols),
        "req_str4": '{"req": "market.%s.detail"}' % random.choice(symbols),
    }
    # Frames are gzip-compressed; wbits=31 makes zlib expect a gzip header.
    ws_result = zlib.decompressobj(31).decompress(message)
    result = json.loads(ws_result.decode('utf-8'))
    print(result)
    # Milliseconds since the connection opened (recordSt is set in on_open).
    recordCost = round((time.time() - recordSt) * 1000, 3)
    if 'subbed' in result:
        subbedCount = subbedCount + 1
        if subbedCount % 5 == 0:
            print("----------------subbed all topic----------------")
    if 'ch' in result:
        retSubTopicCount = retSubTopicCount + 1
    if 'rep' in result:
        repCount = repCount + 1
        # Latency of this rep relative to the last req burst.
        repList.append(int((time.time() - stSend) * 1000))
    # Answer the server ping with a pong carrying the same id.
    if 'ping' in result:
        pingCount = pingCount + 1
        ping_id = result.get('ping')
        pong_str = '{"pong": %d}' % ping_id
        ws.send(pong_str)
        t3 = ping_id - t2
        t2 = ping_id
        if t3 > 5000:
            print("$$$$$$$time difference ping is %d$$$$$$$ " % t3)
    # Send the next burst of req messages once another 2-3 s have elapsed.
    if recordCost >= (random.randint(2000, 3000) * reqSentCount):
        reqSentCount += 1
        for payload in req_list.values():
            ws.send(payload)
            sendCount = sendCount + 1
        stSend = time.time()
    # Write a statistics line once per minute.
    if recordCost >= (60000 * printCount):
        printCount = printCount + 1
        curTime = time.strftime('%Y-%m-%d %H:%M:%S')
        repList.sort()
        retCount = len(repList)
        if retCount:
            p95 = repList[int(retCount * 0.95)]
            p50 = repList[int(retCount * 0.5)]
        else:
            # No rep received yet: indexing the empty list raised IndexError.
            p95 = p50 = 'N/A'
        writeData = '| 当前时间%s ,req发送条数%s,返回总数据条数%s | 数据95耗时:%s | 数据50耗时:%s | sub返回量:%s ' % (
            curTime, sendCount, repCount, p95, p50, retSubTopicCount)
        # Context manager closes the file even if the write fails.
        with open("GipRecord.txt", "a+", encoding="utf-8") as fid:
            fid.write(writeData + "\n")
# WebSocketApp event callbacks.
def on_error(ws, error):
    """Print an error reported by WebSocketApp.

    ``error`` is usually an exception object, so it is formatted with ``%s``.
    The original ``"occur error " + error`` raised TypeError for non-str errors.
    """
    print("occur error %s" % (error,))
def on_close(ws, *args):
    """Reset the periodic counters when the connection closes.

    ``*args`` absorbs the close status code and close message that newer
    websocket-client versions pass to ``on_close``. Older versions call it
    with ``ws`` only, which still works.
    """
    global printCount
    global reqSentCount
    printCount = 1
    reqSentCount = 1
    print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^con is closed^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
def on_open(ws):
    """Start the timers and subscribe to the market-data topics."""
    global recordSt, stSend
    print("con success ...")
    # recordSt drives the periodic stats, stSend the rep latency measurement.
    recordSt = time.time()
    stSend = time.time()
    topics = [
        '{"sub": "market.%s.kline.1min"}' % random.choice(symbols),
        '{"sub": "market.%s.depth.step0"}' % random.choice(symbols),
        '{"sub": "market.%s.trade.detail"}' % random.choice(symbols),
        '{"sub": "market.%s.detail"}' % random.choice(symbols),
        '{"sub": "market.overview"}',
    ]
    for topic in topics:
        ws.send(topic)
class WSClient(object):
    """Holds a ``websocket.WebSocketApp`` wired to this module's callbacks."""

    def __init__(self, host):
        self.host = host
        self.ws = None  # built by useWebCreate()

    def useWebCreate(self):
        """Build the WebSocketApp for ``self.host``; call before ``execute()``."""
        # websocket.enableTrace(True) can be enabled for frame-level debugging;
        # extra handshake headers would go in a header=[...] argument.
        callbacks = dict(
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
            on_open=on_open,
        )
        self.ws = websocket.WebSocketApp(self.host, **callbacks)

    def execute(self):
        """Run the WebSocketApp event loop via ``run_forever()``."""
        self.ws.run_forever()
class AbstractLocust(Locust):
    """Locust base class whose ``client`` is a ``WSClient`` for ``self.host``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client = WSClient(self.host)
class ApiUser(AbstractLocust):
    """Simulated user that holds one long-lived market-data websocket."""

    host = 'ws://xxx/ws'
    min_wait, max_wait = 10, 1000

    # Named ``task_set`` so the nested class doubles as Locust 0.x's task_set attribute.
    class task_set(TaskSet):
        def on_start(self):
            """Create the websocket app, then wait at the rendezvous point."""
            self.client.useWebCreate()
            all_locusts_spawned.wait()

        @task
        def execute_long_run(self):
            """Run the websocket event loop for this user."""
            self.client.execute()
完整代码3:
# -*- encoding:utf-8 -*-
import websocket
import threading
import time
import zlib
import json
from locust import TaskSet, task, User, between
# Module-level counters shared by the callbacks and the sender thread.
t2 = repCount = sendCount = pingCount = 0
stSend = openTime = reqLen = recordSt = 0
repList = []
printCount = 1
def on_message(ws, message):
    """WebSocketApp message callback: measure ``rep`` latency and answer pings.

    Each ``req`` sent by the sender thread carries its send time (epoch ms) as
    ``id``. Receive time minus ``id`` is the round-trip cost, and replies that
    took 1000 ms or more are appended to ``reqRecord.txt``.
    """
    global repCount
    global pingCount
    # Frames are gzip-compressed; wbits=31 makes zlib expect a gzip header.
    ws_result = zlib.decompressobj(31).decompress(message)
    result = json.loads(ws_result.decode())
    if 'rep' in result:
        repCount = repCount + 1
        # Receive time (ms) minus the send time embedded in the request id.
        retCost = int(round(time.time() * 1000)) - int(result["id"])
        print("##########rep count is %d #############" % repCount)
        if retCost >= 1000:
            curTime = time.strftime('%Y-%m-%d %H:%M:%S')
            writeData = '| 当前时间%s ,发送条数%s,返回总数据条数%s | 耗时:%s毫秒 |' % (curTime, sendCount, repCount, retCost)
            # Context manager closes the file even if the write fails.
            with open("reqRecord.txt", "a+", encoding="utf-8") as fid:
                fid.write(writeData + "\n")
    # Answer the server ping with a pong carrying the same id.
    if 'ping' in result:
        pingCount = pingCount + 1
        ping_id = result.get('ping')
        pong_str = '{"pong": %d}' % ping_id
        ws.send(pong_str)
# WebSocketApp event callbacks.
def on_error(ws, error):
    """Print an error reported by WebSocketApp.

    ``error`` is usually an exception object, so it is formatted with ``%s``.
    The original ``"occur error " + error`` raised TypeError for non-str errors.
    """
    print("occur error %s" % (error,))
def on_close(ws, *args):
    """Log that the connection closed.

    ``*args`` absorbs the close status code and close message that newer
    websocket-client versions pass. Older versions call ``on_close(ws)`` only,
    which still works.
    """
    print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^con is closed^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
class newRun(threading.Thread):
    """Background sender that fires a batch of kline ``req`` messages every second.

    Each request's ``id`` is its send time in epoch milliseconds, which the
    message handler subtracts from the receive time to measure latency.
    """

    # Request templates; "%d" is filled with the send time in ms (used as the id).
    REQ_TEMPLATES = (
        '{"req": "market.etc11pbtc.kline.1min","id":"%d","from":1620717333,"to":1620718333}',
        '{"req": "market.etc12pbtc.kline.5min","id":"%d","from":1620717333,"to":1620718333}',
        '{"req": "market.etc1pbtc.kline.5min","id":"%d","from":1620717333,"to":1620718333}',
        '{"req": "market.etc2pbtc.kline.5min","id":"%d","from":1620717333,"to":1620718333}',
    )

    def __init__(self, ws):
        # daemon=True: run() never returns, so a non-daemon thread would keep
        # the process alive after the test is stopped.
        threading.Thread.__init__(self, daemon=True)
        self.ws = ws

    def run(self):
        """Every second send one batch; print the total roughly once per second."""
        global sendCount
        global printCount
        st = time.time()
        while True:
            time.sleep(1)
            sendCount = sendCount + 1
            for template in self.REQ_TEMPLATES:
                # Stamp each request with its own send time (epoch ms).
                self.ws.send(template % int(round(time.time() * 1000)))
            ct = int((time.time() - st) * 1000)
            if ct >= printCount * 1000:
                print("=====send count is %d====" % sendCount)
                printCount = printCount + 1
def on_open(ws):
    """Once connected, start the background thread that sends the ``req`` batches."""
    print("con success ...")
    newRun(ws).start()
class WSClient(object):
    """Thin holder for a ``websocket.WebSocketApp`` bound to this module's callbacks."""

    def __init__(self, host):
        self.ws, self.host = None, host

    def useWebCreate(self):
        """Create the WebSocketApp for ``self.host``; must precede ``execute()``."""
        app = websocket.WebSocketApp(
            self.host,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
            on_open=on_open,
        )
        self.ws = app

    def execute(self):
        """Enter the WebSocketApp event loop via ``run_forever()``."""
        self.ws.run_forever()
class AbstractLocust(User):
    """Locust 1.x user holding one long-lived websocket (``WSClient``) as ``client``."""

    wait_time = between(0, 1)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): assigning the host here likely overrides any --host given
        # on the command line — confirm this is intended.
        self.host = 'ws://xxxx.com:80/ws'
        self.client = WSClient(self.host)

    @task
    class task_set(TaskSet):
        def on_start(self):
            """Create the websocket app before the first task runs."""
            self.client.useWebCreate()

        @task
        def execute_long_run(self):
            """Run the websocket event loop for this user."""
            self.client.execute()
完整代码4:
# -*- encoding:utf-8 -*-
import base64
import configparser
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib import parse
import linecache
import random
from locust import HttpUser
import websocket
import xlrd
from gevent._semaphore import Semaphore
# from locust import TaskSet, task, Locust,events
from locust import TaskSet, task, User,between
# Rendezvous point (disabled). Re-enable to make every user wait in on_start
# until all users have been hatched:
# all_locusts_spawned = Semaphore()
# all_locusts_spawned.acquire()
#
# def on_hatch_complete(**kwargs):
#     all_locusts_spawned.release()
#
# events.hatch_complete += on_hatch_complete

conf = configparser.ConfigParser()
# The host could instead come from an ini file:
#   conf.read("domain.ini"); domain = conf.get('DevSet', 'ip')
# authfunc._sign signs the auth request for this host.
domain = 'user-data-push.loadtest-5.hk3.huobiapps.com'

# Module-level state shared by the websocket callbacks.
t2 = repCount = sendCount = pingCount = subRepCount = 0
row = recordSt = 0
printCount = 1
subOneMinCount = subSymbolCount = 0
sdTime = endTime = req2002 = 0
startResultTime = endResultTime = 0
# Values summarised in the periodic report.
# NOTE(review): nothing in this script appends to it.
resultTimeList = []
def on_message(ws, message):
    """WebSocketApp message callback for the v2 account-data push service.

    * After a successful ``auth`` reply, sends one randomly chosen subscription.
    * For messages that have an ``action`` but no ``ch`` and carry ``data.ts``
      (the server ping), replies with a pong. It warns when consecutive ping
      timestamps differ by more than 5000.
    * Every 5 minutes, appends a summary of ``resultTimeList`` to
      ``account-record.txt``.
    """
    global t2
    global startResultTime
    global pingCount
    global printCount
    global subOneMinCount
    # Frames on this endpoint are parsed directly as JSON (no decompression step).
    result = json.loads(message)
    print("---- %s" % result)
    # Candidate subscriptions; exactly one is sent after authentication.
    sub_list = {
        "sub_str1": '{"action": "sub","ch": "accounts.update#1"}',
        "sub_str2": '{"action": "sub","ch": "trade.clearing#*#0"}',  # trade clearing
        "sub_str4": '{"action": "sub","ch": "trade.clearing#*#1"}',
        "sub_str3": '{"action": "sub","ch": "orders#*"}',
    }
    if 'action' in result:
        if 'ch' in result:
            if result['ch'] == 'auth' and result.get("code") == 200:
                print("auth success")
                startResultTime = time.time()
                # random.sample() rejects dict views on Python 3.11+ (population
                # must be a sequence), so pick from a list of the keys instead.
                a = [random.choice(list(sub_list))]
                print("************", a)
                ws.send(sub_list[a[0]])
        else:
            data = result.get('data')
            # Only answer messages that actually carry a ping timestamp instead
            # of raising KeyError on any other message without a channel.
            if isinstance(data, dict) and 'ts' in data:
                pingCount = pingCount + 1
                ping_id = data['ts']
                pong_str = '{"action":"pong","params":{"ts": %d}}' % ping_id
                ws.send(pong_str)
                t3 = ping_id - t2
                t2 = ping_id
                if t3 > 5000:
                    print("$$$$$$$time difference ping is %d$$$$$$$ " % t3)
    # Milliseconds since the connection opened (recordSt is set in on_open).
    recordCost = round((time.time() - recordSt) * 1000, 3)
    # Write a summary every 5 minutes (60000 ms * 5).
    if recordCost >= (60000 * 5 * printCount):
        printCount = printCount + 1
        resultTimeList.sort()
        retCount = len(resultTimeList)
        if retCount:
            avg_time = round(sum(resultTimeList) / retCount)
            p70 = resultTimeList[int(retCount * 0.7)]
            p95 = resultTimeList[int(retCount * 0.95)]
            p99 = resultTimeList[int(retCount * 0.99)]
        else:
            # Empty list: the average previously raised ZeroDivisionError.
            avg_time = p70 = p95 = p99 = 'N/A'
        curTime = time.strftime('%Y-%m-%d %H:%M:%S')
        writeData = '| 当前时间%s ,当前存在%s条数据,|平均耗时 : %s | 数据70耗时:%s | 数据95耗时:%s | 数据99耗时:%s |' % (
            curTime, retCount, avg_time, p70, p95, p99)
        # Context manager closes the file even if the write fails.
        with open("account-record.txt", "a+", encoding="utf-8") as fid:
            fid.write(writeData + "\n")
        # Reset the per-period counter.
        subOneMinCount = 0
# WebSocketApp event callbacks.
def on_error(ws, error):
    """Print an error reported by WebSocketApp."""
    text = "occur error %s" % error
    print(text)
def on_close(ws, *args):
    """Reset the report counter when the connection closes.

    ``*args`` absorbs the close status code and close message that newer
    websocket-client versions pass. Older versions call ``on_close(ws)`` only,
    which still works.
    """
    global printCount
    printCount = 1
    print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^con is closed^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
def on_open(ws):
    """Authenticate as soon as the connection opens.

    Picks a random ``access_key,secret_key`` line from the credentials file.
    It then signs an ``auth`` request (HmacSHA256, signature version 2.1) and
    sends it; the reply is handled by the message callback.

    Raises:
        ValueError: if the credentials file has no non-blank lines.
    """
    global recordSt
    recordSt = time.time()
    print("con success ...")
    # Credentials file: one "access_key,secret_key" pair per line.
    file_name = "loadtest-5-UserDataPush"
    # Mode 'rU' was removed in Python 3.11 (text mode already handles newlines);
    # the context manager closes the file, which the original leaked.
    # Blank lines are skipped so they can never be chosen.
    with open(file_name) as fh:
        lines = [ln for ln in fh if ln.strip()]
    count = len(lines)
    print("当前有%s行 " % count)
    if not count:
        raise ValueError("credentials file %r has no entries" % file_name)
    # randint is inclusive: the old randrange(1, count) never chose the last
    # line and failed outright for a one-line file.
    random_col = random.randint(1, count)
    print("当前随机取到%s行" % random_col)
    data = lines[random_col - 1].strip().split(",")
    access_key = data[0]
    secret_key = data[1]
    # Only the access key is printed; the secret key must not end up in logs.
    print("--------" + access_key)
    authData = [
        secret_key.encode('utf-8'),
        {
            "authType": "api",
            "accessKey": access_key,
            "signatureMethod": "HmacSHA256",
            "signatureVersion": "2.1",
            "timestamp": authfunc()._utc()
        },
    ]
    send_msg = authfunc()._auth(authData)
    sendData = {"action": "req", "ch": "auth", "params": send_msg}
    print("+++++++++")
    print(json.dumps(sendData))
    ws.send(json.dumps(sendData))
class WSClient(object):
    """Holds a ``websocket.WebSocketApp`` for the v2 endpoint, with the exchange header."""

    def __init__(self, host):
        self.host = host
        self.ws = None  # built by useWebCreate()

    def useWebCreate(self):
        """Create the WebSocketApp for ``self.host``; call before ``execute()``."""
        # websocket.enableTrace(True) can be enabled here for frame-level debugging.
        handshake_headers = ["X-HB-Exchange-Code:pro"]
        self.ws = websocket.WebSocketApp(
            self.host,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
            on_open=on_open,
            header=handshake_headers,
        )

    def execute(self):
        """Run the WebSocketApp event loop via ``run_forever()``."""
        self.ws.run_forever()
class AbstractLocust(User):
    """Locust 1.x user holding one authenticated v2 websocket as ``client``."""

    wait_time = between(0, 1)

    def __init__(self, *args, **kwargs):
        super(AbstractLocust, self).__init__(*args, **kwargs)
        # Connect to the module-level ``domain``: authfunc._sign builds the
        # signed payload from that same ``domain`` and the path '/ws/v2'. The
        # original assigned a different local ``domain`` here, so the connect
        # host and the signed host disagreed.
        self.host = 'ws://' + domain + '/ws/v2'
        self.client = WSClient(self.host)

    @task
    class task_set(TaskSet):
        def on_start(self):
            """Create the websocket app before the first task runs."""
            self.client.useWebCreate()
            # Rendezvous (disabled): all_locusts_spawned.wait()

        @task
        def execute_long_run(self):
            """Run the websocket event loop for this user."""
            self.client.execute()
# Request signing for the v2 websocket auth message.
class authfunc(object):
    """Helpers that build the HmacSHA256 signature for the ``auth`` request."""

    def _sign(self, param=None, _accessKeySecret=None, host=None, path='/ws/v2'):
        """Return the base64-encoded HmacSHA256 signature of the auth parameters.

        The signed payload is the lines ``GET``, host, path and query string
        joined with newlines. The query string holds the URL-encoded values of
        ``accessKey``, ``signatureMethod``, ``signatureVersion`` and
        ``timestamp``, sorted by key. Missing or non-string values become
        empty strings.

        :param param: dict holding the values above; ``None`` is treated as empty.
        :param _accessKeySecret: secret key as bytes (a str is UTF-8 encoded).
        :param host: host to sign for; defaults to the module-level ``domain``.
        :param path: request path to sign for.
        """
        if param is None:
            # The original reset the wrong variable here, so param.get() then
            # raised AttributeError on None.
            param = {}
        if host is None:
            host = domain
        if isinstance(_accessKeySecret, str):
            _accessKeySecret = _accessKeySecret.encode('utf-8')
        params = {}
        for name in ('signatureMethod', 'signatureVersion', 'accessKey', 'timestamp'):
            value = param.get(name)
            params[name] = value if isinstance(value, str) else ''
        # Query string like a=1&b=%20&c=, keys in sorted order.
        qs = '&'.join('%s=%s' % (key, self._encode(params[key])) for key in sorted(params))
        payload = '%s\n%s\n%s\n%s' % ('GET', host, path, qs)
        dig = hmac.new(_accessKeySecret, msg=payload.encode('utf-8'), digestmod=hashlib.sha256).digest()
        return base64.b64encode(dig).decode()

    def _encode(self, s):
        """URL-encode *s*, escaping every reserved character (including '/')."""
        return parse.quote(s, safe='')

    def _utc(self):
        """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SS``."""
        from datetime import timezone  # local: only ``datetime`` is imported at file level
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC now()
        # produces the same text with this format.
        return datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S')

    def _auth(self, auth):
        """Sign ``auth[1]`` in place with the secret ``auth[0]`` and return it."""
        _accessKeySecret, authenticaton_data = auth[0], auth[1]
        authenticaton_data['signature'] = self._sign(authenticaton_data, _accessKeySecret)
        return authenticaton_data
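然后通过locust命令执行(下面是Locust 0.x的参数写法,适用于完整代码1、2;完整代码3、4使用的是Locust 1.x的User类,对应命令为 locust -f xx.py --headless -u 2 -r 1 -t 1m):
locust -f xx.py --no-web -c 2 -r 1 -t 1m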
今天的分享到此就结束了,感谢您的阅读,如果确实帮到您,您可以动动手指转发给其他人。
上一篇
已是最后文章
下一篇
已是最新文章