基于locust的websocket压测

(3) 2024-05-05 18:23

Hi,大家好,我是编程小6,很荣幸遇见你。我把这些年在开发过程中遇到的问题或想法写出来,今天说一说基于locust的websocket压测,希望能够帮助到你!

背景:

locust默认只封装了HttpLocust,其内部是对requests中的session进行的封装;如果想测试其它协议(比如websocket、grpc)怎么办?我们只要重写一个客户端实例赋给client即可:

1、重写WebSocketClient类(主要用来替换掉self.client的http实例)
class WebSocketClient(object):
    """Minimal WebSocket client used in place of Locust's default HTTP client.

    Wraps a ``websocket.WebSocket`` and reports the connect timing to Locust
    through ``events.request_success`` / ``events.request_failure``.
    """

    def __init__(self, host):
        self.host = host
        self.ws = websocket.WebSocket()
        # Result of the most recent connect(); None until a connect succeeds.
        self.conn = None

    def connect(self, burl):
        """Connect to ``burl`` and report the elapsed time to Locust.

        Args:
            burl: WebSocket URL to connect to.

        Returns:
            Whatever ``WebSocket.connect`` returned, or ``None`` when the
            attempt timed out.

        Raises:
            Any non-timeout connection error, after it has been reported to
            Locust as a failed request.
        """
        # Reset first so a failed attempt never returns a stale value and a
        # first-attempt timeout cannot hit an undefined attribute.
        self.conn = None
        start_time = time.time()
        try:
            self.conn = self.ws.connect(url=burl)
        except websocket.WebSocketTimeoutException as e:
            total_time = int((time.time() - start_time) * 1000)
            events.request_failure.fire(request_type="websockt", name='urlweb', response_time=total_time, exception=e)
        except Exception as e:
            # Other connect errors still propagate as before, but are now
            # also counted as failures in the Locust statistics.
            total_time = int((time.time() - start_time) * 1000)
            events.request_failure.fire(request_type="websockt", name='urlweb', response_time=total_time, exception=e)
            raise
        else:
            total_time = int((time.time() - start_time) * 1000)
            events.request_success.fire(request_type="websockt", name='urlweb', response_time=total_time, response_length=0)
        return self.conn

    def recv(self):
        """Return the next message received on the socket."""
        return self.ws.recv()

    def send(self, msg):
        """Send ``msg`` over the socket."""
        self.ws.send(msg)

注意:该类中定义了websocket的常用操作:连接、接收、发送;最主要的是events.request_failure.fire和events.request_success.fire,这两个事件用来收集性能数据,如果不写,报告中就收集不到性能数据


2、重写一个HttpLocust类,我们这里叫做WebsocketLocust类
class WebsocketLocust(Locust):
    """Locust base class whose ``client`` is a WebSocketClient instead of an HTTP session."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built from the host configured on the locust class.
        self.client = WebSocketClient(self.host)

注意:WebsocketLocust从Locust继承;这里主要是将self.client重新实例化成我们第一步写好的WebSocketClient实例


3、编写TaskSet类

class SupperDianCan(TaskSet):
    """Task set that opens the chat WebSocket and answers every frame it receives."""

    @task
    def test_baidu(self):
        """Connect, then loop forever: echo keepalive frames, otherwise send ``self.data``."""
        self.url = 'wss://xxxxxx.xxxx.com/cart/chat?sid=11303&table_no=103&user_id=ofZjWs40HxEzvV08l6m4PnqGbxqc_2_1_&version=2'
        self.data = {}

        self.client.connect(self.url)
        while True:
            message = self.client.recv()
            print(message)
            # SECURITY NOTE(review): eval() executes whatever text the server
            # sent. If the payload is JSON, json.loads() is the safe parser —
            # confirm the wire format before switching.
            is_keepalive = eval(message)['type'] == 'keepalive'
            reply = message if is_keepalive else self.data
            self.client.send(reply)

注意:此类就是任务类,跟http的写法一样,只是这里用的self.client.xxxx已经变成了我们自己重写的websocket类,将原来的requests http替换了


4、编写站点类
class WebsiteUser(WebsocketLocust):
    """Simulated user that runs the SupperDianCan task set."""

    # Bounds of the wait between tasks — presumably milliseconds, as in
    # Locust's legacy min_wait/max_wait API; confirm against the Locust version.
    min_wait = 5000
    max_wait = 9000

    task_set = SupperDianCan

注意:站点类从第二步中的WebsocketLocust继承

完整代码1:

from locust import Locust, events, task, TaskSet

import websocket

import time

import gzip

 

class WebSocketClient:
    """Placeholder client object that only remembers the target host."""

    def __init__(self, host):
        # Only the host is stored; the original's port attribute was commented out.
        self.host = host

 

class WebSocketLocust(Locust):
    """Locust base class that exposes a WebSocketClient as ``self.client``."""

    def __init__(self, *args, **kwargs):
        # Run the base-class initialiser first; the original skipped it, so
        # any setup Locust performs in __init__ never happened.
        super().__init__(*args, **kwargs)
        self.client = WebSocketClient("1xx.xx.xx.85")

 

class UserBehavior(TaskSet):
    """Task set that subscribes to a kline topic and keeps the socket alive."""

    @task(1)
    def buy(self):
        """Subscribe to the market topic and answer server pings until an error occurs.

        Every received frame is gzip-decompressed. When it starts with
        ``{"ping"`` a matching pong is sent back and the subscription is
        re-sent. Any exception ends the loop, is printed, and the socket is
        closed.
        """
        send_info = '{"sub": "market.ethusdt.kline.1min","id": "id10"}'
        ws = None
        try:
            ws = websocket.WebSocket()
            ws.connect("ws://xxxx.com/r1/xx/ws")
            ws.send(send_info)

            # The original wrapped this in a second ``while True`` that could
            # never iterate twice, because the receive loop has no exit.
            while True:
                compressData = ws.recv()
                result = gzip.decompress(compressData).decode('utf-8')

                if result[:7] == '{"ping"':
                    # Characters 8..20 hold the timestamp from the ping frame.
                    ts = result[8:21]
                    pong = '{"pong":' + ts + '}'
                    ws.send(pong)
                    ws.send(send_info)
        except Exception as e:
            print("error is:", e)
        finally:
            # Release the socket once the task ends; the original leaked it.
            if ws is not None:
                ws.close()

 

class ApiUser(WebSocketLocust):
    """Simulated user that runs UserBehavior."""

    # The original mixed 4- and 5-space indentation in this class body, which
    # is an IndentationError; the attributes now share one indent level.
    task_set = UserBehavior

    # Bounds of the wait between tasks — presumably milliseconds, as in
    # Locust's legacy min_wait/max_wait API; confirm against the Locust version.
    min_wait = 100
    max_wait = 200

完整代码2:

# -*- encoding:utf-8 -*-

import gzip
import json
import random
import threading
import time
import zlib
from threading import Timer

import websocket
from gevent._semaphore import Semaphore
from locust import TaskSet, task, Locust, events

# TODO: set up the rendezvous point...
# The semaphore is acquired immediately; it is released again only from
# on_hatch_complete below. Presumably this makes the wait() call in the task
# set's on_start block until hatching is complete — confirm against gevent's
# Semaphore semantics.
all_locusts_spawned = Semaphore()
all_locusts_spawned.acquire()


def on_hatch_complete(**kwargs):
    """Release the rendezvous semaphore; registered on ``events.hatch_complete`` below."""
    all_locusts_spawned.release()


events.hatch_complete += on_hatch_complete

# Module-level state shared by the WebSocket callbacks below.

# Symbols substituted into the sub/req topic templates.
symbols = ["etcusdt"]

# Message counters.
repCount = 0          # frames containing 'rep'
sendCount = 0         # 'req' messages sent
pingCount = 0         # pings received
subbedCount = 0       # frames containing 'subbed'
retSubTopicCount = 0  # frames containing 'ch'

# Timing bookkeeping.
t2 = 0                # value of the previous ping
stSend = 0            # time of the most recent 'req' send
recordSt = 0          # time the connection was opened
repList = []          # 'rep' latencies in milliseconds

# Multipliers selecting the next report / next request burst.
printCount = 1
reqSentCount = 1

# Not referenced by the callbacks visible in this listing.
openTime = 0
reqLen = 0
testFlag = 0

def on_message(ws, message):
    """Handle one compressed JSON frame received on the market socket.

    Updates the module-level counters, answers server pings with a pong,
    periodically sends a burst of ``req`` queries and, once per minute of
    connection time, appends a latency summary line to ``GipRecord.txt``.

    Args:
        ws: Connection the frame arrived on; used to send replies.
        message: Raw compressed payload.
    """
    global t2
    global repCount
    global sendCount
    global pingCount
    global stSend
    global printCount
    global subbedCount
    global retSubTopicCount
    global reqSentCount

    # Decompress the payload (wbits=31 selects the gzip container) and parse it.
    ws_result = zlib.decompressobj(31).decompress(message)
    result = json.loads(ws_result.decode('utf-8'))
    print(result)

    # Milliseconds elapsed since on_open stored recordSt.
    recordCost = round((time.time() - recordSt) * 1000, 3)

    if 'subbed' in result:
        subbedCount = subbedCount + 1
        if subbedCount % 5 == 0:
            print("----------------subbed all topic----------------")

    if 'ch' in result:
        retSubTopicCount = retSubTopicCount + 1

    if 'rep' in result:
        repCount = repCount + 1
        # Latency is measured from the most recent 'req' send.
        repRetTime = int((time.time() - stSend) * 1000)
        repList.append(repRetTime)

    # Answer a server ping with the matching pong.
    if 'ping' in result:
        pingCount = pingCount + 1
        ping_id = result.get('ping')
        pong_str = '{"pong": %d}' % ping_id
        ws.send(pong_str)

        # Report when the gap between two consecutive ping values exceeds 5000.
        t3 = ping_id - t2
        t2 = ping_id
        if t3 > 5000:
            print("$$$$$$$time difference ping is %d$$$$$$$ " % t3)

    # Send the next burst of 'req' queries once enough time has elapsed.
    if recordCost >= (random.randint(2000, 3000) * reqSentCount):
        reqSentCount += 1
        req_list = [
            '{"req": "market.%s.kline.1min"}' % random.choice(symbols),
            '{"req": "market.%s.depth.step0"}' % random.choice(symbols),
            '{"req": "market.%s.trade.detail"}' % random.choice(symbols),
            '{"req": "market.%s.detail"}' % random.choice(symbols),
        ]
        for req in req_list:
            ws.send(req)
            sendCount = sendCount + 1
            stSend = time.time()

    # Write one statistics line per minute.
    if recordCost >= (60000 * printCount):
        printCount = printCount + 1

        curTime = time.strftime('%Y-%m-%d %H:%M:%S')
        repList.sort()
        retCount = len(repList)

        # With no 'rep' recorded yet there is nothing to index; the original
        # raised IndexError here.
        if retCount:
            p95 = repList[int(retCount * 0.95)]
            p50 = repList[int(retCount * 0.5)]
        else:
            p95 = p50 = "N/A"

        writeData = '| 当前时间%s ,req发送条数%s,返回总数据条数%s |  数据95耗时:%s  | 数据50耗时:%s  | sub返回量:%s ' % (
            curTime, sendCount, repCount, p95, p50, retSubTopicCount)

        # The context manager closes the file even if the write fails.
        with open("GipRecord.txt", "a+") as fid:
            fid.write(writeData + "\n")

# Custom implementations of the WebSocketApp event callbacks.
def on_error(ws, error):
    """Print an error reported for the connection.

    Args:
        ws: Connection the error belongs to (unused).
        error: The error; usually an exception instance.
    """
    # Format instead of concatenating: "str" + exception raises TypeError,
    # which hid the real error.
    print("occur error %s" % (error,))


def on_close(ws, close_status_code=None, close_msg=None):
    """Reset the per-connection multipliers when the socket closes.

    Args:
        ws: Connection that was closed (unused).
        close_status_code: Close code. Optional so the callback works both
            with websocket-client versions that pass it and those that don't.
        close_msg: Close reason; optional for the same reason.
    """
    global printCount
    global reqSentCount
    printCount = 1
    reqSentCount = 1
    print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^con is closed^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")


def on_open(ws):
    """Record the connection start time and subscribe to the market topics.

    Args:
        ws: The freshly opened connection; the subscriptions are sent on it.
    """
    global recordSt
    global stSend

    print("con success ...")

    recordSt = time.time()  # start of the statistics window
    stSend = time.time()

    # Build every subscription message before sending any of them.
    symbol_templates = (
        '{"sub": "market.%s.kline.1min"}',
        '{"sub": "market.%s.depth.step0"}',
        '{"sub": "market.%s.trade.detail"}',
        '{"sub": "market.%s.detail"}',
    )
    subscriptions = [template % random.choice(symbols) for template in symbol_templates]
    subscriptions.append('{"sub": "market.overview"}')

    for subscription in subscriptions:
        ws.send(subscription)


class WSClient(object):
    """Holds the target URL and the WebSocketApp created for it."""

    def __init__(self, host):
        self.host = host
        self.ws = None  # created by useWebCreate()

    def useWebCreate(self):
        """Create the WebSocketApp for ``self.host`` with the module-level callbacks."""
        callbacks = dict(
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
            on_open=on_open,
        )
        self.ws = websocket.WebSocketApp(self.host, **callbacks)

    def execute(self):
        """Run the WebSocketApp's ``run_forever`` loop."""
        self.ws.run_forever()


class AbstractLocust(Locust):
    """Locust base class that gives each simulated user a WSClient as ``client``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One WebSocket client per user, bound to the class-level host.
        self.client = WSClient(self.host)


class ApiUser(AbstractLocust):
    """Simulated user that opens one WebSocket and runs it until it closes."""

    host = 'ws://xxx/ws'

    # Bounds of the wait between tasks — presumably milliseconds, as in
    # Locust's legacy min_wait/max_wait API; confirm against the Locust version.
    min_wait = 10
    max_wait = 1000

    class task_set(TaskSet):
        """Creates the socket on start, then runs it as the only task."""

        def on_start(self):
            """Create the WebSocketApp, then wait at the rendezvous semaphore."""
            client = self.client
            client.useWebCreate()
            # TODO: set up the rendezvous point...
            all_locusts_spawned.wait()

        @task
        def execute_long_run(self):
            """Run the WebSocketApp loop."""
            client = self.client
            client.execute()

完整代码3

# -*- encoding:utf-8 -*-

import websocket
import threading
import time
import zlib
import json
from locust import TaskSet, task, User, between



# Module-level state shared by the callbacks and the sender thread below.

# Message counters.
repCount = 0    # frames containing 'rep'
sendCount = 0   # sender-loop iterations
pingCount = 0   # pings received

# Number of the next "send count" progress line printed by the sender.
printCount = 1

# Not modified by the code visible in this listing.
t2 = 0
stSend = 0
openTime = 0
reqLen = 0
recordSt = 0
repList = []


def on_message(ws, message):
    """Handle one compressed JSON frame: time 'rep' answers and answer pings.

    For a frame containing ``rep`` the round-trip time is the current time in
    milliseconds minus the frame's ``id``. Round trips of one second or more
    are appended to ``reqRecord.txt``. A frame containing ``ping`` is
    answered with the matching pong.

    Args:
        ws: Connection the frame arrived on; used to send the pong.
        message: Raw compressed payload.
    """
    global repCount
    global pingCount

    # Decompress the payload (wbits=31 selects the gzip container) and parse it.
    ws_result = zlib.decompressobj(31).decompress(message)
    result = json.loads(ws_result.decode())

    if 'rep' in result:
        repCount = repCount + 1

        # Receive time minus the id carried in the frame, both in ms.
        retCost = int(round(time.time() * 1000)) - int(result["id"])

        print("##########rep  count is  %d   #############" % repCount)

        if retCost >= 1000:
            curTime = time.strftime('%Y-%m-%d %H:%M:%S')
            writeData = '| 当前时间%s ,发送条数%s,返回总数据条数%s |  耗时:%s毫秒  |' % (curTime, sendCount, repCount, retCost)

            # The context manager closes the file even if the write fails.
            with open("reqRecord.txt", "a+") as fid:
                fid.write(writeData + "\n")

    # Answer a server ping with the matching pong.
    if 'ping' in result:
        pingCount = pingCount + 1
        ping_id = result.get('ping')
        pong_str = '{"pong": %d}' % ping_id
        ws.send(pong_str)

# Custom implementations of the WebSocketApp event callbacks.
def on_error(ws, error):
    """Print an error reported for the connection.

    Args:
        ws: Connection the error belongs to (unused).
        error: The error; usually an exception instance.
    """
    # Format instead of concatenating: "str" + exception raises TypeError,
    # which hid the real error.
    print("occur error %s" % (error,))


def on_close(ws, close_status_code=None, close_msg=None):
    """Print a banner when the connection closes.

    Args:
        ws: Connection that was closed (unused).
        close_status_code: Close code. Optional so the callback works both
            with websocket-client versions that pass it and those that don't.
        close_msg: Close reason; optional for the same reason.
    """
    print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^con is closed^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")


class newRun(threading.Thread):
    """Background sender: once per second pushes four kline ``req`` messages.

    Each message carries the current time in milliseconds as its ``id``, so
    the receiver can compute the round-trip time.
    """

    def __init__(self, ws):
        threading.Thread.__init__(self)
        self.ws = ws

    def run(self):
        """Send the request batch every second until the connection is closed."""
        global sendCount
        global printCount

        # (symbol, period) pairs substituted into the request template.
        topics = (
            ("etc11pbtc", "1min"),
            ("etc12pbtc", "5min"),
            ("etc1pbtc", "5min"),
            ("etc2pbtc", "5min"),
        )

        st = time.time()

        while True:
            time.sleep(1)

            sendCount = sendCount + 1

            # All ids are taken before the first message goes out.
            everyReqList = [
                '{"req": "market.%s.kline.%s","id":"%d","from":1620717333,"to":1620718333}' % (
                    symbol, period, int(round(time.time() * 1000)))
                for symbol, period in topics
            ]

            try:
                for req in everyReqList:
                    self.ws.send(req)
            except websocket.WebSocketConnectionClosedException:
                # The socket is gone; end the thread cleanly rather than
                # dying with an unhandled exception.
                print("sender stopped: connection is closed")
                return

            ct = int((time.time() - st) * 1000)

            if ct >= printCount * 1000:
                print("=====send count is %d====" % sendCount)
                printCount = printCount + 1



def on_open(ws):
    """Announce the connection and start the background sender thread for it.

    Args:
        ws: The freshly opened connection handed to the sender.
    """
    print("con success ...")
    newRun(ws).start()



class WSClient(object):
    """Holds the target URL and the WebSocketApp created for it."""

    def __init__(self, host):
        self.host = host
        self.ws = None  # created by useWebCreate()

    def useWebCreate(self):
        """Create the WebSocketApp for ``self.host`` with the module-level callbacks."""
        callbacks = dict(
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
            on_open=on_open,
        )
        self.ws = websocket.WebSocketApp(self.host, **callbacks)

    def execute(self):
        """Run the WebSocketApp's ``run_forever`` loop."""
        self.ws.run_forever()


class AbstractLocust(User):
    """Locust user that drives one WebSocket connection through a WSClient."""

    wait_time = between(0, 1)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.host = 'ws://xxxx.com:80/ws'
        self.client = WSClient(self.host)

    @task
    class task_set(TaskSet):
        """Creates the socket on start, then runs it as the only task."""

        def on_start(self):
            """Create the WebSocketApp before the first task runs."""
            client = self.client
            client.useWebCreate()

        @task
        def execute_long_run(self):
            """Run the WebSocketApp loop."""
            client = self.client
            client.execute()

完整代码4

# -*- encoding:utf-8 -*-

import base64
import configparser
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib import parse
import linecache
import random
from locust import HttpUser

import websocket
import xlrd
from gevent._semaphore import Semaphore
# from locust import TaskSet, task, Locust,events
from locust import TaskSet, task, User,between

# # TODO: 设置集合点...
# all_locusts_spawned = Semaphore()
# all_locusts_spawned.acquire()
#
#
# def on_hatch_complete(**kwargs):
#     all_locusts_spawned.release()

#
# events.hatch_complete += on_hatch_complete

# Configuration. The ini-file lookup is disabled; the domain is hard-coded.
conf = configparser.ConfigParser()
# # conf.read("domain.ini")
# # domain = conf.get('DevSet','ip')
domain = 'user-data-push.loadtest-5.hk3.huobiapps.com'
# file_name = "tokenads.txt"

# Module-level state shared by the WebSocket callbacks below.

# Message counters.
repCount = 0
sendCount = 0
pingCount = 0
subRepCount = 0
subOneMinCount = 0
subSymbolCount = 0
req2002 = 0
row = 0

# Timing bookkeeping.
t2 = 0                # 'ts' value of the previous ping
recordSt = 0          # time the connection was opened
sdTime = 0
endTime = 0
startResultTime = 0   # time the subscription was sent
endResultTime = 0
resultTimeList = []   # samples summarised by the periodic report

# Number of the next periodic report.
printCount = 1




# print("+++++++++" + domain)
def on_message(ws, message):
    """Handle one JSON frame from the user-data push socket.

    After a successful ``auth`` reply one randomly chosen channel is
    subscribed. A frame with ``action`` but no ``ch`` is treated as a ping
    and answered with a pong. Every five minutes of connection time a
    latency summary line is appended to ``account-record.txt``.

    Args:
        ws: Connection the frame arrived on; used to send replies.
        message: JSON text of the frame (not compressed on this endpoint).
    """
    global t2
    global startResultTime
    global pingCount
    global printCount
    global subOneMinCount

    result = json.loads(message)

    print("---- %s" % result)

    sub_list = {
        "sub_str1": '{"action": "sub","ch": "accounts.update#1"}',
        "sub_str2": '{"action": "sub","ch": "trade.clearing#*#0"}',  # trades
        "sub_str4": '{"action": "sub","ch": "trade.clearing#*#1"}',
        "sub_str3": '{"action": "sub","ch": "orders#*"}',
    }

    if 'action' in result:
        if 'ch' in result:
            if result['ch'] == 'auth' and result["code"] == 200:
                print("auth success")

                # Subscribe to exactly one random channel. The original
                # wrapped this in a for loop that always broke after its
                # first pass.
                startResultTime = time.time()
                # random.sample() needs a real sequence; passing dict_keys
                # directly raises TypeError on Python 3.11+.
                a = random.sample(list(sub_list), 1)
                print("************", a)
                ws.send(sub_list[a[0]])
        else:
            # A frame with no channel is handled as a ping.
            pingCount = pingCount + 1

            ping_id = result['data']['ts']
            pong_str = '{"action":"pong","params":{"ts": %d}}' % ping_id
            ws.send(pong_str)

            # Report when the gap between two consecutive ping values exceeds 5000.
            t3 = ping_id - t2
            t2 = ping_id
            if t3 > 5000:
                print("$$$$$$$time difference ping is %d$$$$$$$ " % t3)

    # Write one statistics line every five minutes.
    recordCost = round((time.time() - recordSt) * 1000, 3)  # ms since on_open

    if recordCost >= (60000 * 5 * printCount):
        printCount = printCount + 1
        resultTimeList.sort()
        retCount = len(resultTimeList)

        # With no samples the original divided by zero and indexed an empty
        # list; report placeholders instead.
        if retCount:
            avg_time = round(sum(resultTimeList) / retCount)
            p70 = resultTimeList[int(retCount * 0.7)]
            p95 = resultTimeList[int(retCount * 0.95)]
            p99 = resultTimeList[int(retCount * 0.99)]
        else:
            avg_time = p70 = p95 = p99 = "N/A"

        curTime = time.strftime('%Y-%m-%d %H:%M:%S')

        writeData = '| 当前时间%s ,当前存在%s条数据,|平均耗时 : %s | 数据70耗时:%s  |   数据95耗时:%s  | 数据99耗时:%s  |' % (
            curTime, retCount, avg_time, p70, p95, p99)

        # The context manager closes the file even if the write fails.
        with open("account-record.txt", "a+") as fid:
            fid.write(writeData + "\n")

        # Reset once per report.
        subOneMinCount = 0


# Custom implementations of the WebSocketApp event callbacks.
def on_error(ws, error):
    """Print an error reported for the connection (``ws`` is unused)."""
    report = "occur error %s" % error
    print(report)


def on_close(ws, close_status_code=None, close_msg=None):
    """Reset the report multiplier when the socket closes.

    Args:
        ws: Connection that was closed (unused).
        close_status_code: Close code. Optional so the callback works both
            with websocket-client versions that pass it and those that don't.
        close_msg: Close reason; optional for the same reason.
    """
    global printCount
    printCount = 1
    print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^con is closed^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")


def on_open(ws):
    """Authenticate the new connection with a randomly chosen key pair.

    Reads ``access_key,secret_key`` lines from the file
    ``loadtest-5-UserDataPush``, picks one at random, signs an auth request
    with it and sends that request over ``ws``.

    Args:
        ws: The freshly opened connection.

    Raises:
        ValueError: If the key file contains no non-blank line.
    """
    global recordSt

    recordSt = time.time()

    print("con success ...")

    # Signing: load the credential lines. Mode 'rU' was removed in Python
    # 3.11, and the original never closed the file.
    file_name = "loadtest-5-UserDataPush"
    with open(file_name, 'r') as key_file:
        all_lines = key_file.readlines()
    count = len(all_lines)
    print("当前有%s行 " % count)

    # Keep 1-based line numbers and skip blank lines. The original's
    # randrange(1, count) could never pick the last line and failed outright
    # on a one-line file.
    candidates = [(number, text) for number, text in enumerate(all_lines, 1) if text.strip()]
    if not candidates:
        raise ValueError("no credentials found in %s" % file_name)
    random_col, line = random.choice(candidates)
    print("当前随机取到%s行" % random_col)
    print(line)

    data = line.strip().split(",")
    access_key = data[0]
    secret_key = data[1]

    print("--------" + access_key)
    # NOTE: the secret key itself is no longer printed, to keep it out of logs.

    authData = [
        secret_key.encode('utf-8'),
        {
            "authType": "api",
            "accessKey": access_key,
            "signatureMethod": "HmacSHA256",
            "signatureVersion": "2.1",
            "timestamp": authfunc()._utc()
        },
    ]

    send_msg = authfunc()._auth(authData)

    sendData = {"action": "req", "ch": "auth", "params": send_msg}

    print("+++++++++")
    print(json.dumps(sendData))

    ws.send(json.dumps(sendData))


class WSClient(object):
    """Holds the target URL and the WebSocketApp created for it."""

    def __init__(self, host):
        self.host = host
        self.ws = None  # created by useWebCreate()

    def useWebCreate(self):
        """Create the WebSocketApp for ``self.host`` with callbacks and the exchange header."""
        options = dict(
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
            on_open=on_open,
            header=["X-HB-Exchange-Code:pro"],
        )
        self.ws = websocket.WebSocketApp(self.host, **options)

    def execute(self):
        """Run the WebSocketApp's ``run_forever`` loop."""
        self.ws.run_forever()




class AbstractLocust(User):
    """Locust user that drives one authenticated WebSocket through a WSClient."""

    # Must be a class attribute: the original assigned it to a local variable
    # inside __init__, where it had no effect.
    wait_time = between(0, 1)

    def __init__(self, *args, **kwargs):
        super(AbstractLocust, self).__init__(*args, **kwargs)
        # NOTE(review): this local shadows the module-level ``domain``, which
        # authfunc._sign uses as the signed host — confirm that the two are
        # meant to differ.
        domain = 'xx.huobiapps.com'
        self.host = 'ws://' + domain + '/ws/v2'
        self.client = WSClient(self.host)

    @task
    class task_set(TaskSet):
        """Creates the socket on start, then runs it as the only task."""

        def on_start(self):
            """Create the WebSocketApp before the first task runs."""
            self.client.useWebCreate()

        @task
        def execute_long_run(self):
            """Run the WebSocketApp loop."""
            self.client.execute()


# Request signing helpers.
class authfunc(object):
    """Builds the HMAC-SHA256 signature for the WebSocket auth request."""

    def _sign(self, param=None, _accessKeySecret=None):
        """Return the base64-encoded signature for ``param``.

        The signed payload is ``GET``, the module-level ``domain``, the path
        ``/ws/v2`` and the URL-encoded signature fields sorted by key, joined
        by newlines.

        Args:
            param: Mapping with ``signatureMethod``, ``signatureVersion``,
                ``accessKey`` and ``timestamp``. Missing or non-string values
                are signed as empty strings. ``None`` is treated as an empty
                mapping (the original crashed on it).
            _accessKeySecret: Secret key as bytes, used as the HMAC key.

        Returns:
            The signature as a base64 text string.
        """
        if param is None:
            param = {}

        params = {}
        for field in ('signatureMethod', 'signatureVersion', 'accessKey', 'timestamp'):
            value = param.get(field)
            params[field] = value if isinstance(value, str) else ''

        # Sort by key and build a query string like: a=1&b=%20&c=
        keys = sorted(params.keys())
        _host = domain
        path = '/ws/v2'
        qs = '&'.join(['%s=%s' % (key, self._encode(params[key])) for key in keys])

        payload = '%s\n%s\n%s\n%s' % ('GET', _host, path, qs)
        dig = hmac.new(_accessKeySecret, msg=payload.encode('utf-8'), digestmod=hashlib.sha256).digest()

        return base64.b64encode(dig).decode()

    def _encode(self, s):
        """Percent-encode ``s`` with no character treated as safe."""
        return parse.quote(s, safe='')

    def _utc(self):
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS``."""
        # time.gmtime() yields the same text as the deprecated datetime.utcnow().
        return time.strftime('%Y-%m-%dT%H:%M:%S', time.gmtime())

    def _auth(self, auth):
        """Add a ``signature`` entry to the auth parameters and return them.

        Args:
            auth: Two-item sequence ``[secret_key_bytes, params_dict]``.

        Returns:
            ``params_dict`` itself, updated in place with ``signature``.
        """
        authenticaton_data = auth[1]
        _accessKeySecret = auth[0]

        authenticaton_data['signature'] = self._sign(authenticaton_data, _accessKeySecret)

        return authenticaton_data

 

然后通过locust命令执行(下面是locust 1.0之前的写法;locust 1.0及以后的版本请把 --no-web 换成 --headless,把 -c 换成 -u):

locust -f xx.py  --no-web -c 2 -r 1 -t 1m

今天的分享到此就结束了,感谢您的阅读,如果确实帮到您,您可以动动手指转发给其他人。

上一篇

已是最后文章

下一篇

已是最新文章

发表回复