168 lines
4.6 KiB
Go
168 lines
4.6 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
lua "github.com/yuin/gopher-lua"
|
|
)
|
|
|
|
// TestResult is the outcome of one assertion recorded while running a Lua
// test script.
type TestResult struct {
	// Pass reports whether the assertion succeeded.
	Pass bool

	// Message describes the assertion; for failures it may also carry
	// details such as the actual vs. expected values or a Lua error.
	Message string
}
|
|
|
|
// RunLuaTestScript executes a Lua script against a Response.
|
|
//
|
|
// Globals available in the script:
|
|
//
|
|
// status (number) HTTP status code, e.g. 200
|
|
// status_text (string) Full status string, e.g. "200 OK"
|
|
// body (string) Raw response body
|
|
// headers (table) Response headers, lower-cased keys
|
|
// json_body (table) Parsed JSON body, or nil if not JSON
|
|
// duration_ms (number) Round-trip duration in milliseconds
|
|
// size (number) Response body size in bytes
|
|
//
|
|
// Helper functions:
|
|
//
|
|
// pass(message) Record a passing assertion
|
|
// fail(message) Record a failing assertion (does NOT abort the script)
|
|
// assert_eq(a, b, msg) fail(msg) if a ~= b, else pass(msg)
|
|
// assert_status(code) assert_eq(status, code, "status == "..code)
|
|
// json_decode(str) Parse a JSON string → Lua table
|
|
func RunLuaTestScript(script string, resp *Response) ([]TestResult, error) {
|
|
if strings.TrimSpace(script) == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
L := lua.NewState(lua.Options{SkipOpenLibs: false})
|
|
defer L.Close()
|
|
|
|
var results []TestResult
|
|
|
|
// Inject pass / fail helpers
|
|
L.SetGlobal("pass", L.NewFunction(func(L *lua.LState) int {
|
|
msg := L.OptString(1, "")
|
|
results = append(results, TestResult{Pass: true, Message: msg})
|
|
return 0
|
|
}))
|
|
L.SetGlobal("fail", L.NewFunction(func(L *lua.LState) int {
|
|
msg := L.OptString(1, "")
|
|
results = append(results, TestResult{Pass: false, Message: msg})
|
|
return 0
|
|
}))
|
|
L.SetGlobal("assert_eq", L.NewFunction(func(L *lua.LState) int {
|
|
a := L.Get(1)
|
|
b := L.Get(2)
|
|
msg := L.OptString(3, fmt.Sprintf("%v == %v", a, b))
|
|
if lua.LVCanConvToString(a) && lua.LVCanConvToString(b) &&
|
|
lua.LVAsString(a) == lua.LVAsString(b) {
|
|
results = append(results, TestResult{Pass: true, Message: msg})
|
|
} else {
|
|
results = append(results, TestResult{
|
|
Pass: false,
|
|
Message: fmt.Sprintf("%s (got %v, expected %v)", msg, a, b),
|
|
})
|
|
}
|
|
return 0
|
|
}))
|
|
L.SetGlobal("assert_status", L.NewFunction(func(L *lua.LState) int {
|
|
expected := L.CheckInt(1)
|
|
msg := fmt.Sprintf("status == %d", expected)
|
|
if resp.StatusCode == expected {
|
|
results = append(results, TestResult{Pass: true, Message: msg})
|
|
} else {
|
|
results = append(results, TestResult{
|
|
Pass: false,
|
|
Message: fmt.Sprintf("%s (got %d)", msg, resp.StatusCode),
|
|
})
|
|
}
|
|
return 0
|
|
}))
|
|
|
|
// json_decode helper
|
|
L.SetGlobal("json_decode", L.NewFunction(func(L *lua.LState) int {
|
|
s := L.CheckString(1)
|
|
var v interface{}
|
|
if err := json.Unmarshal([]byte(s), &v); err != nil {
|
|
L.Push(lua.LNil)
|
|
L.Push(lua.LString(err.Error()))
|
|
return 2
|
|
}
|
|
L.Push(goToLua(L, v))
|
|
return 1
|
|
}))
|
|
|
|
// Set response globals
|
|
L.SetGlobal("status", lua.LNumber(resp.StatusCode))
|
|
L.SetGlobal("status_text", lua.LString(resp.Status))
|
|
L.SetGlobal("body", lua.LString(resp.Body))
|
|
L.SetGlobal("duration_ms", lua.LNumber(resp.Duration.Milliseconds()))
|
|
L.SetGlobal("size", lua.LNumber(resp.Size))
|
|
|
|
// headers table (lower-cased keys)
|
|
headersTbl := L.NewTable()
|
|
for k, vals := range resp.Headers {
|
|
if len(vals) > 0 {
|
|
L.SetField(headersTbl, strings.ToLower(k), lua.LString(vals[0]))
|
|
}
|
|
}
|
|
L.SetGlobal("headers", headersTbl)
|
|
|
|
// json_body: parse resp.Body if JSON
|
|
var jsonBody interface{}
|
|
if resp.IsJSON {
|
|
if err := json.Unmarshal([]byte(resp.Body), &jsonBody); err == nil {
|
|
L.SetGlobal("json_body", goToLua(L, jsonBody))
|
|
} else {
|
|
L.SetGlobal("json_body", lua.LNil)
|
|
}
|
|
} else {
|
|
L.SetGlobal("json_body", lua.LNil)
|
|
}
|
|
|
|
if err := L.DoString(script); err != nil {
|
|
// Include any results collected before the error
|
|
errResult := TestResult{
|
|
Pass: false,
|
|
Message: "Lua error: " + err.Error(),
|
|
}
|
|
results = append(results, errResult)
|
|
return results, nil // return results, not the error, so UI can display them
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// goToLua converts a Go value (from json.Unmarshal) into a lua.LValue.
|
|
func goToLua(L *lua.LState, v interface{}) lua.LValue {
|
|
if v == nil {
|
|
return lua.LNil
|
|
}
|
|
switch val := v.(type) {
|
|
case bool:
|
|
return lua.LBool(val)
|
|
case float64:
|
|
return lua.LNumber(val)
|
|
case string:
|
|
return lua.LString(val)
|
|
case []interface{}:
|
|
tbl := L.NewTable()
|
|
for i, item := range val {
|
|
tbl.RawSetInt(i+1, goToLua(L, item))
|
|
}
|
|
return tbl
|
|
case map[string]interface{}:
|
|
tbl := L.NewTable()
|
|
for k, item := range val {
|
|
L.SetField(tbl, k, goToLua(L, item))
|
|
}
|
|
return tbl
|
|
default:
|
|
return lua.LString(fmt.Sprintf("%v", val))
|
|
}
|
|
}
|