635 lines
14 KiB
Go
635 lines
14 KiB
Go
package shell
|
|
|
|
import (
|
|
"fmt"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
// executeIf handles: if COND; then BODY; [elif COND; then BODY;]* [else BODY;] fi
|
|
func (s *Shell) executeIf(block string) error {
|
|
stmts := splitStatements(block)
|
|
|
|
type branch struct {
|
|
cond []string
|
|
body []string
|
|
}
|
|
|
|
var branches []branch
|
|
var elseBody []string
|
|
|
|
phase := "if_cond"
|
|
var curCond []string
|
|
var curBody []string
|
|
|
|
for _, stmt := range stmts {
|
|
w := firstWord(stmt)
|
|
rest := afterWord(stmt)
|
|
|
|
switch {
|
|
case w == "if" && phase == "if_cond":
|
|
if rest != "" {
|
|
curCond = append(curCond, rest)
|
|
}
|
|
case w == "then":
|
|
if rest != "" {
|
|
curBody = append(curBody, rest)
|
|
}
|
|
phase = "body"
|
|
case w == "elif":
|
|
branches = append(branches, branch{curCond, curBody})
|
|
curCond = nil
|
|
curBody = nil
|
|
if rest != "" {
|
|
curCond = append(curCond, rest)
|
|
}
|
|
phase = "elif_cond"
|
|
case w == "else":
|
|
branches = append(branches, branch{curCond, curBody})
|
|
curCond = nil
|
|
curBody = nil
|
|
if rest != "" {
|
|
elseBody = append(elseBody, rest)
|
|
}
|
|
phase = "else"
|
|
case w == "fi":
|
|
switch phase {
|
|
case "body":
|
|
branches = append(branches, branch{curCond, curBody})
|
|
case "else":
|
|
if rest != "" {
|
|
elseBody = append(elseBody, rest)
|
|
}
|
|
}
|
|
default:
|
|
switch phase {
|
|
case "if_cond", "elif_cond":
|
|
curCond = append(curCond, stmt)
|
|
case "body":
|
|
curBody = append(curBody, stmt)
|
|
case "else":
|
|
elseBody = append(elseBody, stmt)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, b := range branches {
|
|
cond := strings.Join(b.cond, "\n")
|
|
s.Execute(cond) //nolint — we only care about $?
|
|
if s.vars["?"] == "0" {
|
|
return s.Execute(strings.Join(b.body, "\n"))
|
|
}
|
|
}
|
|
if len(elseBody) > 0 {
|
|
return s.Execute(strings.Join(elseBody, "\n"))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// executeCase handles: case WORD in PAT1) BODY1 ;; PAT2) BODY2 ;; esac
|
|
// splitStatements emits ";;" as a separate token, so we can parse arms directly.
|
|
func (s *Shell) executeCase(block string) error {
|
|
stmts := splitStatements(block)
|
|
if len(stmts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// First stmt: "case WORD in"
|
|
caseHeader := stmts[0]
|
|
caseRest := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(caseHeader), "case"))
|
|
|
|
var word string
|
|
var startIdx int
|
|
|
|
// Find trailing " in" to separate word from "in"
|
|
inIdx := strings.LastIndex(caseRest, " in")
|
|
if inIdx >= 0 && strings.TrimSpace(caseRest[inIdx+3:]) == "" {
|
|
word = s.expandWord(strings.TrimSpace(caseRest[:inIdx]))
|
|
startIdx = 1
|
|
} else {
|
|
word = s.expandWord(caseRest)
|
|
startIdx = 1
|
|
if startIdx < len(stmts) && strings.TrimSpace(stmts[startIdx]) == "in" {
|
|
startIdx++
|
|
}
|
|
}
|
|
|
|
// Parse arms from remaining stmts
|
|
// stmts now include ";;" as explicit tokens
|
|
type arm struct {
|
|
patterns []string
|
|
body []string
|
|
}
|
|
var arms []arm
|
|
var curArm *arm
|
|
|
|
for _, stmt := range stmts[startIdx:] {
|
|
stmt = strings.TrimSpace(stmt)
|
|
if stmt == "" {
|
|
continue
|
|
}
|
|
if stmt == "esac" {
|
|
break
|
|
}
|
|
if stmt == ";;" || stmt == ";&" || stmt == ";;&" {
|
|
curArm = nil
|
|
continue
|
|
}
|
|
|
|
if curArm == nil {
|
|
// Expect a pattern: PAT) [body]
|
|
parenIdx := findCasePatternEnd(stmt)
|
|
if parenIdx < 0 {
|
|
continue
|
|
}
|
|
patStr := strings.TrimSpace(stmt[:parenIdx])
|
|
bodyStr := strings.TrimSpace(stmt[parenIdx+1:])
|
|
|
|
rawPats := strings.Split(patStr, "|")
|
|
var pats []string
|
|
for _, p := range rawPats {
|
|
p = strings.TrimSpace(p)
|
|
if p != "" {
|
|
pats = append(pats, p)
|
|
}
|
|
}
|
|
arms = append(arms, arm{patterns: pats})
|
|
curArm = &arms[len(arms)-1]
|
|
if bodyStr != "" {
|
|
curArm.body = append(curArm.body, bodyStr)
|
|
}
|
|
} else {
|
|
curArm.body = append(curArm.body, stmt)
|
|
}
|
|
}
|
|
|
|
// Execute matching arm
|
|
for _, a := range arms {
|
|
for _, pat := range a.patterns {
|
|
expandedPat := s.expandWord(pat)
|
|
matched := false
|
|
if expandedPat == "*" {
|
|
matched = true
|
|
} else {
|
|
if m, err := filepath.Match(expandedPat, word); err == nil {
|
|
matched = m
|
|
} else {
|
|
matched = expandedPat == word
|
|
}
|
|
}
|
|
if matched {
|
|
body := strings.Join(a.body, "\n")
|
|
return s.Execute(body)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// findCasePatternEnd returns the byte index of the ")" that ends the pattern
// list of a case arm, or -1 if chunk contains no such character.
//
// A ")" is ignored when it appears inside single or double quotes, or when it
// is escaped with a backslash. Backslash escapes are honoured outside quotes
// and inside double quotes (so `\"` does not close a double-quoted string);
// inside single quotes a backslash is an ordinary character.
func findCasePatternEnd(chunk string) int {
	inSingle, inDouble := false, false
	for i := 0; i < len(chunk); i++ {
		c := chunk[i]
		switch {
		case c == '\\' && !inSingle:
			// Skip the escaped byte so it can neither toggle a quote nor
			// terminate the pattern.
			i++
		case c == '\'' && !inDouble:
			inSingle = !inSingle
		case c == '"' && !inSingle:
			inDouble = !inDouble
		case c == ')' && !inSingle && !inDouble:
			return i
		}
	}
	return -1
}
|
|
|
|
// executeFor handles: for VAR in WORDS; do BODY; done
|
|
// (also: for VAR; do BODY; done — iterates positional params)
|
|
func (s *Shell) executeFor(block string) error {
|
|
stmts := splitStatements(block)
|
|
if len(stmts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Parse "for VAR in WORDS"
|
|
header := stmts[0]
|
|
// Use tokenize so array expansion works in "for x in ${arr[@]}"
|
|
fields := s.tokenize(header)
|
|
if len(fields) < 2 {
|
|
return fmt.Errorf("for: bad syntax")
|
|
}
|
|
varName := fields[1]
|
|
|
|
var items []string
|
|
inIdx := -1
|
|
for i, w := range fields {
|
|
if w == "in" {
|
|
inIdx = i
|
|
break
|
|
}
|
|
}
|
|
if inIdx >= 0 {
|
|
// Items are already expanded by tokenize
|
|
items = fields[inIdx+1:]
|
|
} else {
|
|
// for var; do ... → iterate positional params
|
|
items = s.args
|
|
}
|
|
|
|
// Collect body between "do" and "done"
|
|
var bodyStmts []string
|
|
inBody := false
|
|
for _, stmt := range stmts[1:] {
|
|
w := firstWord(stmt)
|
|
if !inBody {
|
|
if w == "do" {
|
|
inBody = true
|
|
if rest := afterWord(stmt); rest != "" {
|
|
bodyStmts = append(bodyStmts, rest)
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
if w == "done" {
|
|
break
|
|
}
|
|
bodyStmts = append(bodyStmts, stmt)
|
|
}
|
|
|
|
body := strings.Join(bodyStmts, "\n")
|
|
|
|
for _, item := range items {
|
|
s.vars[varName] = item
|
|
if err := s.Execute(body); err != nil {
|
|
if be, ok := err.(breakErr); ok {
|
|
if be.n <= 1 {
|
|
break
|
|
}
|
|
return breakErr{be.n - 1}
|
|
}
|
|
if ce, ok := err.(continueErr); ok {
|
|
if ce.n <= 1 {
|
|
continue
|
|
}
|
|
return continueErr{ce.n - 1}
|
|
}
|
|
if _, ok := err.(returnErr); ok {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// executeForC handles C-style: for ((init; cond; incr)); do BODY; done
|
|
func (s *Shell) executeForC(block string) error {
|
|
stmts := splitStatements(block)
|
|
if len(stmts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Extract the ((...)) header
|
|
header := stmts[0]
|
|
// Find "for" keyword and then "(("
|
|
trimmed := strings.TrimSpace(header)
|
|
// Strip "for" keyword
|
|
trimmed = strings.TrimSpace(trimmed[3:]) // skip "for"
|
|
// Expect "(("
|
|
if !strings.HasPrefix(trimmed, "((") {
|
|
return fmt.Errorf("for: bad C-style syntax")
|
|
}
|
|
trimmed = trimmed[2:] // skip "(("
|
|
// Find closing "))"
|
|
endIdx := strings.Index(trimmed, "))")
|
|
if endIdx < 0 {
|
|
return fmt.Errorf("for: missing '))'")
|
|
}
|
|
inner := trimmed[:endIdx]
|
|
|
|
// Split on ; to get init, cond, incr
|
|
parts := strings.SplitN(inner, ";", 3)
|
|
init := ""
|
|
cond := ""
|
|
incr := ""
|
|
if len(parts) >= 1 {
|
|
init = strings.TrimSpace(parts[0])
|
|
}
|
|
if len(parts) >= 2 {
|
|
cond = strings.TrimSpace(parts[1])
|
|
}
|
|
if len(parts) >= 3 {
|
|
incr = strings.TrimSpace(parts[2])
|
|
}
|
|
|
|
// Execute init as arithmetic assignment
|
|
if init != "" {
|
|
s.execArithAssign(init)
|
|
}
|
|
|
|
// Collect body between "do" and "done"
|
|
var bodyStmts []string
|
|
inBody := false
|
|
for _, stmt := range stmts[1:] {
|
|
w := firstWord(stmt)
|
|
if !inBody {
|
|
if w == "do" {
|
|
inBody = true
|
|
if rest := afterWord(stmt); rest != "" {
|
|
bodyStmts = append(bodyStmts, rest)
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
if w == "done" {
|
|
break
|
|
}
|
|
bodyStmts = append(bodyStmts, stmt)
|
|
}
|
|
|
|
body := strings.Join(bodyStmts, "\n")
|
|
|
|
for {
|
|
// Evaluate condition
|
|
if cond != "" {
|
|
if s.evalArith(cond) == 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := s.Execute(body); err != nil {
|
|
if be, ok := err.(breakErr); ok {
|
|
if be.n <= 1 {
|
|
break
|
|
}
|
|
return breakErr{be.n - 1}
|
|
}
|
|
if ce, ok := err.(continueErr); ok {
|
|
if ce.n <= 1 {
|
|
// continue — execute incr then re-check cond
|
|
if incr != "" {
|
|
s.execArithAssign(incr)
|
|
}
|
|
continue
|
|
}
|
|
return continueErr{ce.n - 1}
|
|
}
|
|
if _, ok := err.(returnErr); ok {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Execute increment
|
|
if incr != "" {
|
|
s.execArithAssign(incr)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// execArithAssign handles arithmetic assignment expressions like i=0, i++, i+=1, ((i++))
|
|
func (s *Shell) execArithAssign(expr string) {
|
|
expr = strings.TrimSpace(expr)
|
|
if expr == "" {
|
|
return
|
|
}
|
|
|
|
// Handle i++ and i--
|
|
if strings.HasSuffix(expr, "++") {
|
|
varName := strings.TrimSpace(expr[:len(expr)-2])
|
|
if isValidIdentifier(varName) {
|
|
n := s.evalArith(varName)
|
|
s.vars[varName] = fmt.Sprintf("%d", n+1)
|
|
return
|
|
}
|
|
}
|
|
if strings.HasSuffix(expr, "--") {
|
|
varName := strings.TrimSpace(expr[:len(expr)-2])
|
|
if isValidIdentifier(varName) {
|
|
n := s.evalArith(varName)
|
|
s.vars[varName] = fmt.Sprintf("%d", n-1)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Handle i+=N
|
|
if idx := strings.Index(expr, "+="); idx > 0 {
|
|
varName := strings.TrimSpace(expr[:idx])
|
|
if isValidIdentifier(varName) {
|
|
delta := s.evalArith(expr[idx+2:])
|
|
n := s.evalArith(varName)
|
|
s.vars[varName] = fmt.Sprintf("%d", n+delta)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Handle i-=N
|
|
if idx := strings.Index(expr, "-="); idx > 0 {
|
|
varName := strings.TrimSpace(expr[:idx])
|
|
if isValidIdentifier(varName) {
|
|
delta := s.evalArith(expr[idx+2:])
|
|
n := s.evalArith(varName)
|
|
s.vars[varName] = fmt.Sprintf("%d", n-delta)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Handle i=expr
|
|
if idx := strings.Index(expr, "="); idx > 0 {
|
|
varName := strings.TrimSpace(expr[:idx])
|
|
if isValidIdentifier(varName) {
|
|
n := s.evalArith(expr[idx+1:])
|
|
s.vars[varName] = fmt.Sprintf("%d", n)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// executeWhileUntil handles while/until loops.
|
|
func (s *Shell) executeWhileUntil(block string, isUntil bool) error {
|
|
stmts := splitStatements(block)
|
|
if len(stmts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Parse condition (everything from "while/until COND" up to "do")
|
|
var condStmts []string
|
|
if rest := afterWord(stmts[0]); rest != "" {
|
|
condStmts = append(condStmts, rest)
|
|
}
|
|
|
|
var bodyStmts []string
|
|
inBody := false
|
|
for _, stmt := range stmts[1:] {
|
|
w := firstWord(stmt)
|
|
if !inBody {
|
|
if w == "do" {
|
|
inBody = true
|
|
if rest := afterWord(stmt); rest != "" {
|
|
bodyStmts = append(bodyStmts, rest)
|
|
}
|
|
} else {
|
|
condStmts = append(condStmts, stmt)
|
|
}
|
|
continue
|
|
}
|
|
if w == "done" {
|
|
break
|
|
}
|
|
bodyStmts = append(bodyStmts, stmt)
|
|
}
|
|
|
|
cond := strings.Join(condStmts, "\n")
|
|
body := strings.Join(bodyStmts, "\n")
|
|
|
|
for {
|
|
s.Execute(cond) //nolint
|
|
condOk := s.vars["?"] == "0"
|
|
|
|
if (isUntil && condOk) || (!isUntil && !condOk) {
|
|
break
|
|
}
|
|
|
|
if err := s.Execute(body); err != nil {
|
|
if be, ok := err.(breakErr); ok {
|
|
if be.n <= 1 {
|
|
break
|
|
}
|
|
return breakErr{be.n - 1}
|
|
}
|
|
if ce, ok := err.(continueErr); ok {
|
|
if ce.n <= 1 {
|
|
continue
|
|
}
|
|
return continueErr{ce.n - 1}
|
|
}
|
|
if _, ok := err.(returnErr); ok {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// defineFunction parses and registers a shell function definition.
|
|
func (s *Shell) defineFunction(block string) error {
|
|
stmts := splitStatements(block)
|
|
if len(stmts) == 0 {
|
|
return fmt.Errorf("syntax error: empty function")
|
|
}
|
|
|
|
first := stmts[0]
|
|
var name string
|
|
|
|
if strings.HasPrefix(first, "function ") {
|
|
rest := strings.TrimPrefix(first, "function ")
|
|
rest = strings.TrimSpace(rest)
|
|
// Strip trailing () and {
|
|
rest = strings.TrimSuffix(strings.TrimSpace(rest), "{")
|
|
rest = strings.TrimSuffix(strings.TrimSpace(rest), "()")
|
|
name = strings.TrimSpace(rest)
|
|
} else {
|
|
parenIdx := strings.Index(first, "(")
|
|
if parenIdx < 0 {
|
|
return fmt.Errorf("syntax error: bad function definition")
|
|
}
|
|
name = strings.TrimSpace(first[:parenIdx])
|
|
}
|
|
|
|
if !isValidIdentifier(name) {
|
|
return fmt.Errorf("syntax error: invalid function name %q", name)
|
|
}
|
|
|
|
// Find the opening { in the block — it may be on the same line as the name
|
|
// or on a following stmt. Everything after { (up to closing }) is the body.
|
|
var bodyStmts []string
|
|
inBody := false
|
|
|
|
for _, stmt := range stmts {
|
|
trimmed := strings.TrimSpace(stmt)
|
|
|
|
if !inBody {
|
|
// Look for { in this stmt
|
|
braceIdx := strings.Index(trimmed, "{")
|
|
if braceIdx >= 0 {
|
|
inBody = true
|
|
rest := strings.TrimSpace(trimmed[braceIdx+1:])
|
|
if rest != "" && rest != "}" {
|
|
bodyStmts = append(bodyStmts, rest)
|
|
}
|
|
// Check if } is also on this line (single-liner like name() { cmd; })
|
|
if strings.HasSuffix(trimmed, "}") && braceIdx < len(trimmed)-1 {
|
|
// body is between { and }
|
|
inner := strings.TrimSpace(trimmed[braceIdx+1 : len(trimmed)-1])
|
|
bodyStmts = nil
|
|
if inner != "" {
|
|
bodyStmts = append(bodyStmts, inner)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
if trimmed == "}" {
|
|
break
|
|
}
|
|
bodyStmts = append(bodyStmts, stmt)
|
|
}
|
|
|
|
funcBody := strings.Join(bodyStmts, "\n")
|
|
s.funcs[name] = funcBody
|
|
|
|
funcName := name
|
|
s.builtins[funcName] = func(args []string) error {
|
|
return s.callFunction(funcName, args)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Shell) callFunction(name string, args []string) error {
|
|
body, ok := s.funcs[name]
|
|
if !ok {
|
|
return fmt.Errorf("%s: function not found", name)
|
|
}
|
|
|
|
// Save positional params and exit code
|
|
oldArgs := s.args
|
|
savedPos := map[string]string{}
|
|
for k, v := range s.vars {
|
|
if k == "#" || k == "@" || k == "*" || (len(k) == 1 && k[0] >= '1' && k[0] <= '9') {
|
|
savedPos[k] = v
|
|
}
|
|
}
|
|
|
|
s.SetArgs(args)
|
|
s.vars["?"] = "0" // reset before running body
|
|
|
|
err := s.Execute(body)
|
|
|
|
// Capture the function's exit code BEFORE restoring params (which might not include ?)
|
|
funcExitCode := s.lastExit
|
|
|
|
// Restore positional params
|
|
s.args = oldArgs
|
|
for k, v := range savedPos {
|
|
s.vars[k] = v
|
|
}
|
|
|
|
if re, ok := err.(returnErr); ok {
|
|
if re.code != 0 {
|
|
return exitCodeErr{re.code}
|
|
}
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Propagate last command's exit code from the function body
|
|
if funcExitCode != 0 {
|
|
return exitCodeErr{funcExitCode}
|
|
}
|
|
return nil
|
|
}
|