Add control flow, I/O redirection, functions, coreutils, history/completion

- I/O redirection: >, >>, <, 2>, 2>&1, &>
- Job control: background & operator
- if/elif/else/fi, for/do/done, while/until loops
- Shell functions with local/declare, positional param save/restore
- Exit code propagation via exitCodeErr sentinel
- Arithmetic expansion $((expr)) with bare variable names
- Command substitution $(cmd) with pipeline support
- Glob expansion, tilde expansion, ${VAR:-default} and other forms
- Tab completion and command history via chzyer/readline
- Inline comment stripping (# outside quotes)
- Builtins: test/[, read, printf, tr, sed, cut, tail, tee, xargs,
  basename, dirname, date, sleep, uniq, sort, wc, head, grep, find,
  true, false, break, continue, return, shift, set, unset, export,
  declare/local, source, alias, jobs, command, which, env
- Bug fixes: tokenizer parenDepth double-count for $((,
  splitPipe not paren-aware (broke pipelines in $()),
  local/declare TrimLeft stripping valid var name chars,
  parseBlocks missing nested keywords after do/then/else

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cametendo
2026-05-26 12:50:06 +02:00
parent eba49c46bc
commit 8c6a2ab4c2
8 changed files with 3669 additions and 869 deletions

View File

@@ -1,10 +1,13 @@
package bash
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/chzyer/readline"
"github.com/cametendo/bash-for-windows/internal/shell"
)
@@ -12,51 +15,109 @@ func Run() error {
args := os.Args[1:]
if len(args) > 0 {
if args[0] == "-c" && len(args) > 1 {
// Execute a command string
return runCommand(strings.Join(args[1:], " "))
switch args[0] {
case "-c":
if len(args) < 2 {
return fmt.Errorf("-c: option requires an argument")
}
sh := shell.New()
if len(args) > 2 {
sh.SetArgs(args[2:])
}
return sh.Execute(strings.Join(args[1:], " "))
case "--version":
fmt.Println("bash-for-windows 2.0.0 (Go-based)")
fmt.Println("Provides bash-compatible shell for Windows.")
return nil
default:
return runFile(args[0], args[1:])
}
// Run a script file
return runFile(args[0])
}
return interactive()
}
func runCommand(cmd string) error {
sh := shell.New()
return sh.Execute(cmd)
}
func runFile(path string) error {
func runFile(path string, args []string) error {
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("%s: %v", path, err)
}
sh := shell.New()
sh.SetArgs(append([]string{path}, args...))
sh.SetVar("0", path)
return sh.Execute(string(data))
}
func historyFile() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
// On Windows this lands in %USERPROFILE%\.bash_history
return filepath.Join(home, ".bash_history")
}
func interactive() error {
sh := shell.New()
reader := bufio.NewReader(os.Stdin)
fmt.Println("bash-for-windows v1.0.0")
// Build completer
completer := readline.NewPrefixCompleter(
readline.PcItemDynamic(func(line string) []string {
return dynamicComplete(sh, line)
}),
)
rl, err := readline.NewEx(&readline.Config{
HistoryFile: historyFile(),
AutoComplete: completer,
InterruptPrompt: "^C",
EOFPrompt: "exit",
HistorySearchFold: true,
FuncFilterInputRune: filterInput,
})
if err != nil {
// Fall back to dumb interactive mode
return interactiveDumb(sh)
}
defer rl.Close()
fmt.Fprintln(os.Stdout, "bash-for-windows v2.0.0 (type 'exit' or Ctrl+D to quit)")
var multiLine strings.Builder
for {
fmt.Print("bash$ ")
input, err := reader.ReadString('\n')
if err != nil {
break
prompt := buildPrompt(sh)
if multiLine.Len() > 0 {
prompt = "> "
}
rl.SetPrompt(prompt)
line, err := rl.Readline()
if err != nil {
if err.Error() == "Interrupt" {
// Ctrl+C: abort current multi-line input
multiLine.Reset()
fmt.Fprintln(os.Stdout)
continue
}
break // EOF
}
if multiLine.Len() > 0 {
multiLine.WriteString("\n")
}
multiLine.WriteString(line)
input := multiLine.String()
if shell.IsIncomplete(input) {
continue // wait for more input
}
multiLine.Reset()
input = strings.TrimSpace(input)
if input == "" {
continue
}
if input == "exit" {
break
}
if err := sh.Execute(input); err != nil {
fmt.Fprintf(os.Stderr, "bash: %v\n", err)
@@ -64,3 +125,137 @@ func interactive() error {
}
return nil
}
// interactiveDumb is a fallback that doesn't need readline.
func interactiveDumb(sh *shell.Shell) error {
fmt.Fprintln(os.Stdout, "bash-for-windows v2.0.0")
var multiLine strings.Builder
buf := make([]byte, 4096)
for {
prompt := buildPrompt(sh)
if multiLine.Len() > 0 {
prompt = "> "
}
fmt.Fprint(os.Stdout, prompt)
n, err := os.Stdin.Read(buf)
if n > 0 {
chunk := string(buf[:n])
if multiLine.Len() > 0 {
multiLine.WriteString("\n")
}
multiLine.WriteString(strings.TrimRight(chunk, "\r\n"))
input := multiLine.String()
if shell.IsIncomplete(input) {
continue
}
multiLine.Reset()
input = strings.TrimSpace(input)
if input == "" {
continue
}
if execErr := sh.Execute(input); execErr != nil {
fmt.Fprintf(os.Stderr, "bash: %v\n", execErr)
}
}
if err != nil {
break
}
}
return nil
}
func buildPrompt(sh *shell.Shell) string {
pwd, _ := os.Getwd()
home, _ := os.UserHomeDir()
if home != "" && strings.HasPrefix(pwd, home) {
pwd = "~" + pwd[len(home):]
}
// Show exit code in prompt if non-zero
exitCode := sh.GetVar("?")
suffix := "$ "
if exitCode != "0" && exitCode != "" {
suffix = "[" + exitCode + "]$ "
}
return pwd + suffix
}
// dynamicComplete provides tab completion for commands and paths.
func dynamicComplete(sh *shell.Shell, line string) []string {
line = strings.TrimLeft(line, " \t")
var completions []string
// Check if we're completing the first word (command) or an argument (path)
words := strings.Fields(line)
completingCommand := len(words) == 0 || (len(words) == 1 && !strings.HasSuffix(line, " "))
if completingCommand {
prefix := ""
if len(words) == 1 {
prefix = words[0]
}
// Builtin/function names — access via the shell instance
// (we can't iterate unexported fields, so use a public method)
for _, name := range sh.BuiltinNames() {
if strings.HasPrefix(name, prefix) {
completions = append(completions, name)
}
}
// PATH executables
for _, dir := range filepath.SplitList(os.Getenv("PATH")) {
entries, err := os.ReadDir(dir)
if err != nil {
continue
}
for _, e := range entries {
name := e.Name()
// Strip .exe on completion display
name = strings.TrimSuffix(name, ".exe")
if strings.HasPrefix(name, prefix) {
completions = append(completions, name)
}
}
}
} else {
// Path completion
prefix := ""
if len(words) > 0 {
prefix = words[len(words)-1]
if strings.HasSuffix(line, " ") {
prefix = ""
}
}
dir := filepath.Dir(prefix)
base := filepath.Base(prefix)
if prefix == "" || strings.HasSuffix(prefix, "/") || strings.HasSuffix(prefix, "\\") {
dir = prefix
base = ""
}
entries, err := os.ReadDir(dir)
if err == nil {
for _, e := range entries {
name := e.Name()
if strings.HasPrefix(name, base) {
p := filepath.Join(dir, name)
if e.IsDir() {
p += "/"
}
completions = append(completions, p)
}
}
}
}
return completions
}
func filterInput(r rune) (rune, bool) {
// Block Ctrl+Z (26) — on Windows this would suspend; we handle it gracefully
if r == 26 {
return r, false
}
return r, true
}

5
go.mod
View File

@@ -1,3 +1,8 @@
module github.com/cametendo/bash-for-windows
go 1.26.2
require (
github.com/chzyer/readline v1.5.1 // indirect
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect
)

6
go.sum Normal file
View File

@@ -0,0 +1,6 @@
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

File diff suppressed because it is too large Load Diff

352
internal/shell/control.go Normal file
View File

@@ -0,0 +1,352 @@
package shell
import (
"fmt"
"strings"
)
// executeIf handles: if COND; then BODY; [elif COND; then BODY;]* [else BODY;] fi
func (s *Shell) executeIf(block string) error {
stmts := splitStatements(block)
type branch struct {
cond []string
body []string
}
var branches []branch
var elseBody []string
phase := "if_cond"
var curCond []string
var curBody []string
for _, stmt := range stmts {
w := firstWord(stmt)
rest := afterWord(stmt)
switch {
case w == "if" && phase == "if_cond":
if rest != "" {
curCond = append(curCond, rest)
}
case w == "then":
if rest != "" {
curBody = append(curBody, rest)
}
phase = "body"
case w == "elif":
branches = append(branches, branch{curCond, curBody})
curCond = nil
curBody = nil
if rest != "" {
curCond = append(curCond, rest)
}
phase = "elif_cond"
case w == "else":
branches = append(branches, branch{curCond, curBody})
curCond = nil
curBody = nil
if rest != "" {
elseBody = append(elseBody, rest)
}
phase = "else"
case w == "fi":
switch phase {
case "body":
branches = append(branches, branch{curCond, curBody})
case "else":
if rest != "" {
elseBody = append(elseBody, rest)
}
}
default:
switch phase {
case "if_cond", "elif_cond":
curCond = append(curCond, stmt)
case "body":
curBody = append(curBody, stmt)
case "else":
elseBody = append(elseBody, stmt)
}
}
}
for _, b := range branches {
cond := strings.Join(b.cond, "\n")
s.Execute(cond) //nolint — we only care about $?
if s.vars["?"] == "0" {
return s.Execute(strings.Join(b.body, "\n"))
}
}
if len(elseBody) > 0 {
return s.Execute(strings.Join(elseBody, "\n"))
}
return nil
}
// executeFor handles: for VAR in WORDS; do BODY; done
// (also: for VAR; do BODY; done — iterates positional params)
func (s *Shell) executeFor(block string) error {
stmts := splitStatements(block)
if len(stmts) == 0 {
return nil
}
// Parse "for VAR in WORDS"
header := stmts[0]
fields := strings.Fields(header)
if len(fields) < 2 {
return fmt.Errorf("for: bad syntax")
}
varName := fields[1]
var items []string
inIdx := -1
for i, w := range fields {
if w == "in" {
inIdx = i
break
}
}
if inIdx >= 0 {
for _, raw := range fields[inIdx+1:] {
expanded := s.expandWord(raw)
items = append(items, s.expandGlob(expanded)...)
}
} else {
// for var; do ... → iterate positional params
items = s.args
}
// Collect body between "do" and "done"
var bodyStmts []string
inBody := false
for _, stmt := range stmts[1:] {
w := firstWord(stmt)
if !inBody {
if w == "do" {
inBody = true
if rest := afterWord(stmt); rest != "" {
bodyStmts = append(bodyStmts, rest)
}
}
continue
}
if w == "done" {
break
}
bodyStmts = append(bodyStmts, stmt)
}
body := strings.Join(bodyStmts, "\n")
for _, item := range items {
s.vars[varName] = item
if err := s.Execute(body); err != nil {
if be, ok := err.(breakErr); ok {
if be.n <= 1 {
break
}
return breakErr{be.n - 1}
}
if ce, ok := err.(continueErr); ok {
if ce.n <= 1 {
continue
}
return continueErr{ce.n - 1}
}
if _, ok := err.(returnErr); ok {
return err
}
}
}
return nil
}
// executeWhileUntil handles while/until loops.
func (s *Shell) executeWhileUntil(block string, isUntil bool) error {
stmts := splitStatements(block)
if len(stmts) == 0 {
return nil
}
// Parse condition (everything from "while/until COND" up to "do")
var condStmts []string
if rest := afterWord(stmts[0]); rest != "" {
condStmts = append(condStmts, rest)
}
var bodyStmts []string
inBody := false
for _, stmt := range stmts[1:] {
w := firstWord(stmt)
if !inBody {
if w == "do" {
inBody = true
if rest := afterWord(stmt); rest != "" {
bodyStmts = append(bodyStmts, rest)
}
} else {
condStmts = append(condStmts, stmt)
}
continue
}
if w == "done" {
break
}
bodyStmts = append(bodyStmts, stmt)
}
cond := strings.Join(condStmts, "\n")
body := strings.Join(bodyStmts, "\n")
for {
s.Execute(cond) //nolint
condOk := s.vars["?"] == "0"
if (isUntil && condOk) || (!isUntil && !condOk) {
break
}
if err := s.Execute(body); err != nil {
if be, ok := err.(breakErr); ok {
if be.n <= 1 {
break
}
return breakErr{be.n - 1}
}
if ce, ok := err.(continueErr); ok {
if ce.n <= 1 {
continue
}
return continueErr{ce.n - 1}
}
if _, ok := err.(returnErr); ok {
return err
}
}
}
return nil
}
// defineFunction parses and registers a shell function definition.
func (s *Shell) defineFunction(block string) error {
stmts := splitStatements(block)
if len(stmts) == 0 {
return fmt.Errorf("syntax error: empty function")
}
first := stmts[0]
var name string
if strings.HasPrefix(first, "function ") {
rest := strings.TrimPrefix(first, "function ")
rest = strings.TrimSpace(rest)
// Strip trailing () and {
rest = strings.TrimSuffix(strings.TrimSpace(rest), "{")
rest = strings.TrimSuffix(strings.TrimSpace(rest), "()")
name = strings.TrimSpace(rest)
} else {
parenIdx := strings.Index(first, "(")
if parenIdx < 0 {
return fmt.Errorf("syntax error: bad function definition")
}
name = strings.TrimSpace(first[:parenIdx])
}
if !isValidIdentifier(name) {
return fmt.Errorf("syntax error: invalid function name %q", name)
}
// Find the opening { in the block — it may be on the same line as the name
// or on a following stmt. Everything after { (up to closing }) is the body.
var bodyStmts []string
inBody := false
for _, stmt := range stmts {
trimmed := strings.TrimSpace(stmt)
if !inBody {
// Look for { in this stmt
braceIdx := strings.Index(trimmed, "{")
if braceIdx >= 0 {
inBody = true
rest := strings.TrimSpace(trimmed[braceIdx+1:])
if rest != "" && rest != "}" {
bodyStmts = append(bodyStmts, rest)
}
// Check if } is also on this line (single-liner like name() { cmd; })
if strings.HasSuffix(trimmed, "}") && braceIdx < len(trimmed)-1 {
// body is between { and }
inner := strings.TrimSpace(trimmed[braceIdx+1 : len(trimmed)-1])
bodyStmts = nil
if inner != "" {
bodyStmts = append(bodyStmts, inner)
}
break
}
}
continue
}
if trimmed == "}" {
break
}
bodyStmts = append(bodyStmts, stmt)
}
funcBody := strings.Join(bodyStmts, "\n")
s.funcs[name] = funcBody
funcName := name
s.builtins[funcName] = func(args []string) error {
return s.callFunction(funcName, args)
}
return nil
}
func (s *Shell) callFunction(name string, args []string) error {
body, ok := s.funcs[name]
if !ok {
return fmt.Errorf("%s: function not found", name)
}
// Save positional params and exit code
oldArgs := s.args
savedPos := map[string]string{}
for k, v := range s.vars {
if k == "#" || k == "@" || k == "*" || (len(k) == 1 && k[0] >= '1' && k[0] <= '9') {
savedPos[k] = v
}
}
s.SetArgs(args)
s.vars["?"] = "0" // reset before running body
err := s.Execute(body)
// Capture the function's exit code BEFORE restoring params (which might not include ?)
funcExitCode := s.lastExit
// Restore positional params
s.args = oldArgs
for k, v := range savedPos {
s.vars[k] = v
}
if re, ok := err.(returnErr); ok {
if re.code != 0 {
return exitCodeErr{re.code}
}
return nil
}
if err != nil {
return err
}
// Propagate last command's exit code from the function body
if funcExitCode != 0 {
return exitCodeErr{funcExitCode}
}
return nil
}

375
internal/shell/exec.go Normal file
View File

@@ -0,0 +1,375 @@
package shell
import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
)
type redirect struct {
fd int // 0=stdin 1=stdout 2=stderr -1=both stdout+stderr
mode string // ">" ">>" "<"
dest string // filename or "&1" "&2"
}
// executePipeline handles background jobs (&) and pipelines (|).
func (s *Shell) executePipeline(input string) error {
input = strings.TrimSpace(input)
if input == "" {
return nil
}
// Background job: trailing & (not &&)
bg := false
if strings.HasSuffix(input, "&") && !strings.HasSuffix(input, "&&") {
bg = true
input = strings.TrimSuffix(strings.TrimSuffix(input, "&"), " ")
}
parts := splitPipe(input)
if len(parts) == 1 {
if bg {
return s.executeCommandBg(strings.TrimSpace(parts[0]))
}
return s.executeCommand(strings.TrimSpace(parts[0]))
}
return s.doPipe(parts, bg)
}
// splitPipe splits by | but not ||, respecting quotes.
func splitPipe(input string) []string {
var parts []string
current := strings.Builder{}
inSingle := false
inDouble := false
parenDepth := 0
pendingDollar := false
for i := 0; i < len(input); i++ {
c := input[i]
switch {
case c == '\'' && !inDouble && parenDepth == 0:
inSingle = !inSingle
current.WriteByte(c)
case c == '"' && !inSingle && parenDepth == 0:
inDouble = !inDouble
current.WriteByte(c)
case c == '$' && !inSingle && !inDouble && i+1 < len(input) && input[i+1] == '(':
pendingDollar = true
current.WriteByte(c)
case c == '(' && !inSingle && !inDouble && (parenDepth > 0 || pendingDollar):
parenDepth++
pendingDollar = false
current.WriteByte(c)
case c == ')' && !inSingle && !inDouble && parenDepth > 0:
parenDepth--
current.WriteByte(c)
case c == '|' && !inSingle && !inDouble && parenDepth == 0:
if i+1 < len(input) && input[i+1] == '|' {
current.WriteByte(c) // part of ||, pass through
} else {
parts = append(parts, strings.TrimSpace(current.String()))
current.Reset()
pendingDollar = false
}
default:
pendingDollar = false
current.WriteByte(c)
}
}
if current.Len() > 0 {
parts = append(parts, strings.TrimSpace(current.String()))
}
return parts
}
// executeCommand executes a single command (no pipes, no &&/||).
func (s *Shell) executeCommand(input string) error {
input = strings.TrimSpace(input)
if input == "" {
return nil
}
tokens := s.tokenize(input)
if len(tokens) == 0 {
return nil
}
cmdArgs, redirects := extractRedirects(tokens)
if len(cmdArgs) == 0 {
// Pure redirection, e.g. "> file" creates/truncates file
return s.withRedirects(redirects, func() error { return nil })
}
cmdName := cmdArgs[0]
args := cmdArgs[1:]
// Alias expansion
if alias, ok := aliases[cmdName]; ok {
full := alias
if len(args) > 0 {
full += " " + strings.Join(args, " ")
}
return s.withRedirects(redirects, func() error {
return s.Execute(full)
})
}
// Builtin
if builtin, ok := s.builtins[cmdName]; ok {
return s.withRedirects(redirects, func() error {
return builtin(args)
})
}
// External
return s.withRedirects(redirects, func() error {
return s.executeExternal(cmdName, args)
})
}
func extractRedirects(tokens []string) ([]string, []redirect) {
var args []string
var redirects []redirect
i := 0
for i < len(tokens) {
tok := tokens[i]
switch {
// 2>&1
case tok == "2>&1":
redirects = append(redirects, redirect{2, ">", "&1"})
i++
// 1>&2
case tok == "1>&2":
redirects = append(redirects, redirect{1, ">", "&2"})
i++
// &> or &>> (both stdout+stderr)
case tok == "&>" || tok == "&>>":
if i+1 < len(tokens) {
redirects = append(redirects, redirect{-1, strings.TrimPrefix(tok, "&"), tokens[i+1]})
i += 2
} else {
i++
}
case strings.HasPrefix(tok, "&>"):
mode := ">"
dest := tok[2:]
if strings.HasPrefix(dest, ">") {
mode = ">>"
dest = dest[1:]
}
redirects = append(redirects, redirect{-1, mode, dest})
i++
// 2> 2>> 2>file
case tok == "2>" || tok == "2>>":
if i+1 < len(tokens) {
redirects = append(redirects, redirect{2, tok[1:], tokens[i+1]})
i += 2
} else {
i++
}
case strings.HasPrefix(tok, "2>>"):
redirects = append(redirects, redirect{2, ">>", tok[3:]})
i++
case strings.HasPrefix(tok, "2>"):
redirects = append(redirects, redirect{2, ">", tok[2:]})
i++
// > >>
case tok == ">" || tok == ">>":
if i+1 < len(tokens) {
redirects = append(redirects, redirect{1, tok, tokens[i+1]})
i += 2
} else {
i++
}
case strings.HasPrefix(tok, ">>") && len(tok) > 2:
redirects = append(redirects, redirect{1, ">>", tok[2:]})
i++
case strings.HasPrefix(tok, ">") && len(tok) > 1:
redirects = append(redirects, redirect{1, ">", tok[1:]})
i++
// < (stdin)
case tok == "<":
if i+1 < len(tokens) {
redirects = append(redirects, redirect{0, "<", tokens[i+1]})
i += 2
} else {
i++
}
case strings.HasPrefix(tok, "<") && len(tok) > 1 && tok[1] != '<':
redirects = append(redirects, redirect{0, "<", tok[1:]})
i++
default:
args = append(args, tok)
i++
}
}
return args, redirects
}
func (s *Shell) withRedirects(redirects []redirect, fn func() error) error {
if len(redirects) == 0 {
return fn()
}
oldIn, oldOut, oldErr := s.Stdin, s.Stdout, s.Stderr
var toClose []io.Closer
defer func() {
s.Stdin, s.Stdout, s.Stderr = oldIn, oldOut, oldErr
for _, c := range toClose {
c.Close()
}
}()
for _, r := range redirects {
switch r.mode {
case ">":
if r.dest == "&1" {
s.Stderr = s.Stdout
} else if r.dest == "&2" {
s.Stdout = s.Stderr
} else {
f, err := os.Create(r.dest)
if err != nil {
return fmt.Errorf("cannot open %s: %v", r.dest, err)
}
toClose = append(toClose, f)
if r.fd == 1 || r.fd == -1 {
s.Stdout = f
}
if r.fd == 2 || r.fd == -1 {
s.Stderr = f
}
}
case ">>":
f, err := os.OpenFile(r.dest, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("cannot open %s: %v", r.dest, err)
}
toClose = append(toClose, f)
if r.fd == 1 || r.fd == -1 {
s.Stdout = f
}
if r.fd == 2 || r.fd == -1 {
s.Stderr = f
}
case "<":
f, err := os.Open(r.dest)
if err != nil {
return fmt.Errorf("cannot open %s: %v", r.dest, err)
}
toClose = append(toClose, f)
s.Stdin = f
}
}
return fn()
}
func (s *Shell) executeExternal(cmdName string, args []string) error {
path := findExecutable(cmdName)
if path == "" {
fmt.Fprintf(s.Stderr, "%s: command not found\n", cmdName)
return exitCodeErr{127}
}
cmd := exec.Command(path, args...)
cmd.Stdin = s.Stdin
cmd.Stdout = s.Stdout
cmd.Stderr = s.Stderr
err := cmd.Run()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return exitCodeErr{exitErr.ExitCode()}
}
}
return err
}
func (s *Shell) executeCommandBg(input string) error {
tokens := s.tokenize(input)
if len(tokens) == 0 {
return nil
}
cmdArgs, _ := extractRedirects(tokens)
if len(cmdArgs) == 0 {
return nil
}
path := findExecutable(cmdArgs[0])
if path == "" {
return fmt.Errorf("%s: command not found", cmdArgs[0])
}
cmd := exec.Command(path, cmdArgs[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return err
}
pid := cmd.Process.Pid
s.vars["!"] = fmt.Sprintf("%d", pid)
fmt.Fprintf(s.Stderr, "[1] %d\n", pid)
go func() { cmd.Wait() }()
return nil
}
func findExecutable(name string) string {
// Direct path
if strings.ContainsAny(name, "/\\") {
if info, err := os.Stat(name); err == nil && !info.IsDir() {
abs, _ := filepath.Abs(name)
return abs
}
return ""
}
path := os.Getenv("PATH")
for _, dir := range filepath.SplitList(path) {
for _, candidate := range []string{
filepath.Join(dir, name),
filepath.Join(dir, name+".exe"),
filepath.Join(dir, name+".cmd"),
filepath.Join(dir, name+".bat"),
} {
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
return candidate
}
}
}
return ""
}
// doPipe executes a pipeline where each stage feeds into the next.
func (s *Shell) doPipe(commands []string, bg bool) error {
_ = bg // background pipe support would require goroutines; skip for now
var prevBuf []byte
for i, cmd := range commands {
isLast := i == len(commands)-1
var stdinReader io.Reader
if i == 0 {
stdinReader = s.Stdin
} else {
stdinReader = bytes.NewReader(prevBuf)
}
if isLast {
return s.withIO(stdinReader, nil, nil, func() error {
return s.executeCommand(cmd)
})
}
var buf bytes.Buffer
s.withIO(stdinReader, &buf, nil, func() error {
return s.executeCommand(cmd)
})
prevBuf = buf.Bytes()
}
return nil
}

448
internal/shell/expand.go Normal file
View File

@@ -0,0 +1,448 @@
package shell
import (
"bytes"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
)
func isVarChar(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_'
}
// expandWord expands $VAR, ${VAR}, $(...), $((...)) in a single token.
// Quote characters (single/double) are interpreted here and stripped from output.
func (s *Shell) expandWord(word string) string {
// Tilde expansion (only when not quoted)
if strings.HasPrefix(word, "~") {
home := s.GetVar("HOME")
if home == "" {
home = os.Getenv("USERPROFILE")
}
if len(word) == 1 {
return home
}
if word[1] == '/' || word[1] == '\\' {
return home + word[1:]
}
}
var result strings.Builder
inSingle := false
inDouble := false
i := 0
for i < len(word) {
ch := word[i]
switch {
case ch == '\'' && !inDouble:
inSingle = !inSingle
i++
case ch == '"' && !inSingle:
inDouble = !inDouble
i++
case ch == '\\' && !inSingle:
if i+1 < len(word) {
next := word[i+1]
if inDouble {
// In double quotes, only certain chars are escaped
switch next {
case '$', '`', '"', '\\', '\n':
result.WriteByte(next)
i += 2
default:
result.WriteByte('\\')
result.WriteByte(next)
i += 2
}
} else {
result.WriteByte(next)
i += 2
}
} else {
i++
}
case ch == '$' && !inSingle:
i++ // skip $
if i >= len(word) {
result.WriteByte('$')
break
}
switch word[i] {
case '(':
if i+1 < len(word) && word[i+1] == '(' {
// $(( arithmetic ))
j := i + 2
depth := 2
for j < len(word) {
if word[j] == '(' {
depth++
}
if word[j] == ')' {
depth--
if depth == 0 {
j++
break
}
}
j++
}
expr := word[i+2 : j-2]
result.WriteString(strconv.Itoa(s.evalArith(expr)))
i = j
} else {
// $( command substitution )
j := i + 1
depth := 1
for j < len(word) {
if word[j] == '(' {
depth++
}
if word[j] == ')' {
depth--
if depth == 0 {
break
}
}
j++
}
cmd := word[i+1 : j]
out := s.captureCommand(cmd)
result.WriteString(strings.TrimRight(out, "\n"))
i = j + 1
}
case '{':
j := i + 1
depth := 1
for j < len(word) {
if word[j] == '{' {
depth++
}
if word[j] == '}' {
depth--
if depth == 0 {
break
}
}
j++
}
varExpr := word[i+1 : j]
result.WriteString(s.evalVarExpr(varExpr))
i = j + 1
case '?':
result.WriteString(s.vars["?"])
i++
case '$':
result.WriteString(fmt.Sprintf("%d", os.Getpid()))
i++
case '!':
result.WriteString(s.vars["!"])
i++
case '#':
result.WriteString(s.vars["#"])
i++
case '@':
result.WriteString(s.vars["@"])
i++
case '*':
result.WriteString(s.vars["*"])
i++
default:
j := i
for j < len(word) && isVarChar(word[j]) {
j++
}
if j == i {
result.WriteByte('$')
} else {
result.WriteString(s.getVar(word[i:j]))
i = j
}
}
default:
result.WriteByte(ch)
i++
}
}
return result.String()
}
func (s *Shell) getVar(name string) string {
if v, ok := s.vars[name]; ok {
return v
}
return os.Getenv(name)
}
func (s *Shell) evalVarExpr(expr string) string {
// ${#VAR} — string length
if strings.HasPrefix(expr, "#") {
return strconv.Itoa(len(s.getVar(expr[1:])))
}
// ${VAR:-default}
if idx := strings.Index(expr, ":-"); idx >= 0 {
varName := expr[:idx]
if v := s.getVar(varName); v != "" {
return v
}
return s.expandWord(expr[idx+2:])
}
// ${VAR:=default}
if idx := strings.Index(expr, ":="); idx >= 0 {
varName := expr[:idx]
if v := s.getVar(varName); v != "" {
return v
}
expanded := s.expandWord(expr[idx+2:])
s.vars[varName] = expanded
return expanded
}
// ${VAR:+alt}
if idx := strings.Index(expr, ":+"); idx >= 0 {
varName := expr[:idx]
if v := s.getVar(varName); v != "" {
return s.expandWord(expr[idx+2:])
}
return ""
}
// ${VAR%pattern} — strip shortest suffix
if idx := strings.Index(expr, "%"); idx >= 0 {
varName := expr[:idx]
pattern := expr[idx+1:]
v := s.getVar(varName)
if strings.HasSuffix(v, pattern) {
return v[:len(v)-len(pattern)]
}
return v
}
// ${VAR#pattern} — strip shortest prefix
if idx := strings.Index(expr, "#"); idx >= 0 {
varName := expr[:idx]
pattern := expr[idx+1:]
v := s.getVar(varName)
if strings.HasPrefix(v, pattern) {
return v[len(pattern):]
}
return v
}
return s.getVar(expr)
}
// captureCommand runs a command and returns its stdout as a string.
func (s *Shell) captureCommand(cmd string) string {
var buf bytes.Buffer
s.withIO(nil, &buf, nil, func() error {
return s.Execute(cmd)
})
return buf.String()
}
// evalArith evaluates a shell arithmetic expression.
func (s *Shell) evalArith(expr string) int {
expr = strings.TrimSpace(s.expandWord(expr))
// Expand bare variable names (e.g. i+1 → value_of_i + 1)
expr = s.expandArithVars(expr)
return evalArithExpr(expr)
}
// expandArithVars replaces bare identifier names with their shell variable values.
func (s *Shell) expandArithVars(expr string) string {
var result strings.Builder
i := 0
for i < len(expr) {
c := expr[i]
// Identifier start (letter or _), but not a digit
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_' {
j := i
for j < len(expr) && isVarChar(expr[j]) {
j++
}
varName := expr[i:j]
val := s.getVar(varName)
if val == "" {
val = "0"
}
result.WriteString(val)
i = j
} else {
result.WriteByte(c)
i++
}
}
return result.String()
}
func evalArithExpr(expr string) int {
expr = strings.TrimSpace(expr)
if n, err := strconv.Atoi(expr); err == nil {
return n
}
// Strip outer parens
if strings.HasPrefix(expr, "(") && strings.HasSuffix(expr, ")") {
return evalArithExpr(expr[1 : len(expr)-1])
}
// Operators in precedence order (lowest first so we split on last occurrence)
for _, op := range []string{"+", "-", "*", "/", "%"} {
if idx := findBinaryOp(expr, op); idx >= 0 {
left := evalArithExpr(expr[:idx])
right := evalArithExpr(expr[idx+1:])
switch op {
case "+":
return left + right
case "-":
return left - right
case "*":
return left * right
case "/":
if right == 0 {
return 0
}
return left / right
case "%":
if right == 0 {
return 0
}
return left % right
}
}
}
return 0
}
func findBinaryOp(expr, op string) int {
depth := 0
for i := len(expr) - 1; i >= 0; i-- {
switch expr[i] {
case ')':
depth++
case '(':
depth--
}
if depth != 0 {
continue
}
if expr[i:i+1] == op {
if (op == "-" || op == "+") && i == 0 {
continue
}
return i
}
}
return -1
}
// expandGlob expands glob patterns; returns original if no match.
func (s *Shell) expandGlob(word string) []string {
if !strings.ContainsAny(word, "*?[") {
return []string{word}
}
matches, err := filepath.Glob(word)
if err != nil || len(matches) == 0 {
return []string{word}
}
return matches
}
// tokenize splits input into tokens, expands variables, handles quotes and globs.
func (s *Shell) tokenize(input string) []string {
var rawTokens []string
current := strings.Builder{}
inSingle := false
inDouble := false
parenDepth := 0 // nesting depth inside $(...) or $((...))
pendingDollar := false // true after $ when next char is (
wasQuoted := false
flush := func() {
if current.Len() > 0 {
tok := current.String()
if wasQuoted {
tok = "\x00q" + tok
}
rawTokens = append(rawTokens, tok)
current.Reset()
wasQuoted = false
pendingDollar = false
}
}
for i := 0; i < len(input); i++ {
c := input[i]
switch {
case c == '\'' && !inDouble && parenDepth == 0:
inSingle = !inSingle
wasQuoted = true
current.WriteByte(c)
case c == '"' && !inSingle && parenDepth == 0:
inDouble = !inDouble
wasQuoted = true
current.WriteByte(c)
case c == '$' && !inSingle && i+1 < len(input) && (input[i+1] == '(' || input[i+1] == '{'):
// Mark that the next ( opens a substitution — don't increment depth here
if input[i+1] == '(' {
pendingDollar = true
}
current.WriteByte(c)
case c == '(' && !inSingle && !inDouble && (parenDepth > 0 || pendingDollar):
parenDepth++
pendingDollar = false
current.WriteByte(c)
case c == ')' && !inSingle && !inDouble && parenDepth > 0:
parenDepth--
current.WriteByte(c)
case c == '{' && !inSingle && !inDouble && parenDepth > 0:
parenDepth++
current.WriteByte(c)
case c == '}' && !inSingle && !inDouble && parenDepth > 0:
parenDepth--
current.WriteByte(c)
case (c == ' ' || c == '\t') && !inSingle && !inDouble && parenDepth == 0:
flush()
case c == '#' && !inSingle && !inDouble && parenDepth == 0 && current.Len() == 0:
// Inline comment: # at start of a new token — discard the rest of the input
goto doneTokenizing
default:
pendingDollar = false
current.WriteByte(c)
}
}
doneTokenizing:
flush()
// Handle variable assignment on token[0]: FOO=bar
if len(rawTokens) > 0 {
tok := rawTokens[0]
clean := strings.TrimPrefix(tok, "\x00q")
if eqIdx := strings.Index(clean, "="); eqIdx > 0 {
name := clean[:eqIdx]
if isValidIdentifier(name) && !strings.Contains(clean[:eqIdx], "$") {
value := s.expandWord(clean[eqIdx+1:])
s.vars[name] = value
os.Setenv(name, value)
rawTokens = rawTokens[1:]
if len(rawTokens) == 0 {
return nil
}
}
}
}
var result []string
for _, tok := range rawTokens {
quoted := strings.HasPrefix(tok, "\x00q")
if quoted {
tok = tok[2:]
}
expanded := s.expandWord(tok)
if !quoted && strings.ContainsAny(expanded, "*?[") {
result = append(result, s.expandGlob(expanded)...)
} else {
result = append(result, expanded)
}
}
return result
}

View File

@@ -1,30 +1,54 @@
package shell
import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
)
// Sentinel errors for control flow
type breakErr struct{ n int }
type continueErr struct{ n int }
type returnErr struct{ code int }
func (e breakErr) Error() string { return fmt.Sprintf("break %d", e.n) }
func (e continueErr) Error() string { return fmt.Sprintf("continue %d", e.n) }
func (e returnErr) Error() string { return fmt.Sprintf("return %d", e.code) }
// exitCodeErr carries a non-zero exit code that sets $? without a message.
// Used by functions, test, false, etc.
type exitCodeErr struct{ code int }
func (e exitCodeErr) Error() string { return "" }
type Shell struct {
vars map[string]string
builtins map[string]func([]string) error
funcs map[string]string // function name → body
lastExit int
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
args []string
}
func New() *Shell {
s := &Shell{
vars: map[string]string{},
lastExit: 0,
funcs: map[string]string{},
Stdin: os.Stdin,
Stdout: os.Stdout,
Stderr: os.Stderr,
}
s.initBuiltins()
s.vars["SHELL"] = "bash-for-windows"
s.vars["BASH_VERSION"] = "1.0.0"
s.vars["BASH_VERSION"] = "5.2.15(1)-release"
s.vars["?"] = "0"
s.vars["#"] = "0"
s.vars["@"] = ""
s.vars["*"] = ""
s.vars["!"] = ""
if pwd, err := os.Getwd(); err == nil {
s.vars["PWD"] = pwd
}
@@ -34,14 +58,296 @@ func New() *Shell {
return s
}
func (s *Shell) SetArgs(args []string) {
s.args = args
s.vars["#"] = fmt.Sprintf("%d", len(args))
s.vars["@"] = strings.Join(args, " ")
s.vars["*"] = strings.Join(args, " ")
for i, a := range args {
s.vars[fmt.Sprintf("%d", i+1)] = a
}
}
func (s *Shell) GetVar(name string) string {
if v, ok := s.vars[name]; ok {
return v
}
return os.Getenv(name)
}
func (s *Shell) SetVar(name, value string) {
s.vars[name] = value
}
// Execute runs commands from the given input string.
func (s *Shell) Execute(input string) error {
input = strings.TrimSpace(input)
if input == "" {
input = strings.ReplaceAll(input, "\\\n", " ")
blocks := parseBlocks(input)
for _, block := range blocks {
if err := s.executeBlock(block); err != nil {
switch err.(type) {
case breakErr, continueErr, returnErr:
return err
}
s.setExitCode(err)
}
}
return nil
}
// IsIncomplete returns true if the input is an incomplete multi-line construct.
func IsIncomplete(input string) bool {
stmts := splitStatements(input)
depth := 0
inSingle := false
inDouble := false
for _, ch := range input {
switch ch {
case '\'':
if !inDouble {
inSingle = !inSingle
}
case '"':
if !inSingle {
inDouble = !inDouble
}
}
}
if inSingle || inDouble {
return true
}
for _, stmt := range stmts {
w := firstWord(stmt)
switch w {
case "if", "for", "while", "until":
depth++
case "fi", "done", "esac":
depth--
case "{":
depth++
case "}":
if depth > 0 {
depth--
}
}
if isFuncDefStart(stmt) && strings.HasSuffix(strings.TrimSpace(stmt), "{") {
depth++
}
}
return depth > 0
}
// parseBlocks groups statements into logical execution units.
// Multi-line if/for/while/function blocks are gathered into single entries.
func parseBlocks(input string) []string {
stmts := splitStatements(input)
var blocks []string
var current []string
kwDepth := 0 // if/for/while/until → fi/done nesting
inFunc := false
funcKwDepth := 0 // keyword nesting inside a function body
for _, stmt := range stmts {
stmt = strings.TrimSpace(stmt)
if stmt == "" || strings.HasPrefix(stmt, "#") {
continue
}
w := firstWord(stmt)
if !inFunc {
// Detect function definition opening with `{`
if isFuncDefStart(stmt) && strings.Contains(stmt, "{") {
braceIdx := strings.Index(stmt, "{")
// Count keywords after the { on this same line
funcKwDepth = 0
for _, p := range splitStatements(stmt[braceIdx+1:]) {
switch firstWord(p) {
case "if", "for", "while", "until":
funcKwDepth++
case "fi", "done", "esac":
funcKwDepth--
}
}
// If the line also ends with } it's a self-contained function
if strings.HasSuffix(stmt, "}") {
current = append(current, stmt)
blocks = append(blocks, strings.Join(current, "\n"))
current = nil
funcKwDepth = 0
continue
}
inFunc = true
current = append(current, stmt)
continue
}
lines := strings.Split(input, "\n")
for _, line := range lines {
switch w {
case "if", "for", "while", "until":
kwDepth++
}
kwDepth += embeddedKwDepth(stmt)
current = append(current, stmt)
switch w {
case "fi", "done", "esac":
kwDepth--
case "}":
if kwDepth > 0 {
kwDepth--
}
}
if kwDepth <= 0 && len(current) > 0 {
kwDepth = 0
blocks = append(blocks, strings.Join(current, "\n"))
current = nil
}
} else {
// Inside function body — watch for } at funcKwDepth==0
if w == "}" && funcKwDepth <= 0 {
current = append(current, stmt)
blocks = append(blocks, strings.Join(current, "\n"))
current = nil
inFunc = false
funcKwDepth = 0
continue
}
switch w {
case "if", "for", "while", "until":
funcKwDepth++
case "fi", "done", "esac":
funcKwDepth--
}
funcKwDepth += embeddedKwDepth(stmt)
current = append(current, stmt)
}
}
if len(current) > 0 {
blocks = append(blocks, strings.Join(current, "\n"))
}
return blocks
}
// embeddedKwDepth returns the net depth change from keywords that appear
// after do/then/else/elif within a single statement (excluding the first word,
// which is handled separately by the caller).
func embeddedKwDepth(stmt string) int {
words := strings.Fields(stmt)
delta := 0
for j := 1; j < len(words); j++ {
switch words[j-1] {
case "do", "then", "else", "elif":
switch words[j] {
case "if", "for", "while", "until":
delta++
case "fi", "done", "esac":
delta--
}
}
}
return delta
}
// splitStatements splits input on semicolons and newlines, respecting quotes.
func splitStatements(input string) []string {
var result []string
current := strings.Builder{}
inSingle := false
inDouble := false
for i := 0; i < len(input); i++ {
c := input[i]
switch {
case c == '\'' && !inDouble:
inSingle = !inSingle
current.WriteByte(c)
case c == '"' && !inSingle:
inDouble = !inDouble
current.WriteByte(c)
case (c == ';' || c == '\n') && !inSingle && !inDouble:
if s := strings.TrimSpace(current.String()); s != "" {
result = append(result, s)
}
current.Reset()
default:
current.WriteByte(c)
}
}
if s := strings.TrimSpace(current.String()); s != "" {
result = append(result, s)
}
return result
}
func firstWord(s string) string {
fields := strings.Fields(s)
if len(fields) == 0 {
return ""
}
return fields[0]
}
func afterWord(s string) string {
for i, ch := range s {
if ch == ' ' || ch == '\t' {
return strings.TrimSpace(s[i:])
}
}
return ""
}
func isFuncDefStart(stmt string) bool {
if strings.HasPrefix(stmt, "function ") {
return true
}
for i, ch := range stmt {
if ch == ' ' || ch == '\t' {
break
}
if ch == '(' {
name := strings.TrimSpace(stmt[:i])
return isValidIdentifier(name)
}
}
return false
}
func isValidIdentifier(s string) bool {
if len(s) == 0 {
return false
}
for i, c := range s {
if i == 0 {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
return false
}
} else {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return false
}
}
}
return true
}
func (s *Shell) executeBlock(block string) error {
block = strings.TrimSpace(block)
if block == "" || strings.HasPrefix(block, "#") {
return nil
}
w := firstWord(block)
switch w {
case "if":
return s.executeIf(block)
case "for":
return s.executeFor(block)
case "while":
return s.executeWhileUntil(block, false)
case "until":
return s.executeWhileUntil(block, true)
}
if isFuncDefStart(block) {
return s.defineFunction(block)
}
for _, line := range strings.Split(block, "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -54,23 +360,20 @@ func (s *Shell) Execute(input string) error {
}
func (s *Shell) executeLine(line string) error {
// Tokenize the line into its components, handling &&, ||, ;
return s.executeChain(line)
}
// executeChain: parse && / || / ; with left-to-right precedence
func (s *Shell) executeChain(line string) error {
// Strategy: split by ; first (semicolons always separate),
// then by &&/|| within each segment
segments := splitBySemicolon(line)
for _, seg := range segments {
for _, seg := range splitBySemicolon(line) {
seg = strings.TrimSpace(seg)
if seg == "" {
continue
}
if err := s.executeAndOrList(seg); err != nil {
// In ; chains, errors in one command don't stop execution
switch err.(type) {
case breakErr, continueErr, returnErr:
return err
}
s.setExitCode(err)
}
}
@@ -79,7 +382,7 @@ func (s *Shell) executeChain(line string) error {
func splitBySemicolon(line string) []string {
var parts []string
current := ""
current := strings.Builder{}
inSingle := false
inDouble := false
@@ -88,32 +391,30 @@ func splitBySemicolon(line string) []string {
switch {
case c == '\'' && !inDouble:
inSingle = !inSingle
current += string(c)
current.WriteByte(c)
case c == '"' && !inSingle:
inDouble = !inDouble
current += string(c)
current.WriteByte(c)
case c == ';' && !inSingle && !inDouble:
parts = append(parts, current)
current = ""
parts = append(parts, current.String())
current.Reset()
default:
current += string(c)
current.WriteByte(c)
}
}
if current != "" {
parts = append(parts, current)
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
// executeAndOrList: parse && / || with left-to-right precedence
func (s *Shell) executeAndOrList(line string) error {
type token struct {
type tok struct {
text string
op string // operator BEFORE this token (except first = "")
op string
}
var tokens []token
current := ""
var tokens []tok
current := strings.Builder{}
op := ""
inSingle := false
inDouble := false
@@ -123,326 +424,86 @@ func (s *Shell) executeAndOrList(line string) error {
switch {
case c == '\'' && !inDouble:
inSingle = !inSingle
current += string(c)
current.WriteByte(c)
case c == '"' && !inSingle:
inDouble = !inDouble
current += string(c)
case c == '&' && !inSingle && !inDouble:
if i+1 < len(line) && line[i+1] == '&' {
if current != "" {
tokens = append(tokens, token{current, op})
current = ""
}
current.WriteByte(c)
case c == '&' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '&':
tokens = append(tokens, tok{current.String(), op})
current.Reset()
op = "&&"
i++
} else {
current += string(c)
}
case c == '|' && !inSingle && !inDouble:
if i+1 < len(line) && line[i+1] == '|' {
if current != "" {
tokens = append(tokens, token{current, op})
current = ""
}
case c == '|' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '|':
tokens = append(tokens, tok{current.String(), op})
current.Reset()
op = "||"
i++
} else {
current += string(c)
}
default:
current += string(c)
current.WriteByte(c)
}
}
if current != "" {
tokens = append(tokens, token{current, op})
}
if len(tokens) == 0 {
return nil
if current.Len() > 0 {
tokens = append(tokens, tok{current.String(), op})
}
var lastErr error
for i, tok := range tokens {
cmd := strings.TrimSpace(tok.text)
for i, t := range tokens {
cmd := strings.TrimSpace(t.text)
if cmd == "" {
continue
}
shouldRun := true
if i > 0 && tok.op == "&&" {
shouldRun = (lastErr == nil)
} else if i > 0 && tok.op == "||" {
shouldRun = (lastErr != nil)
run := i == 0
if i > 0 {
run = (t.op == "&&" && lastErr == nil) || (t.op == "||" && lastErr != nil)
}
if shouldRun {
if run {
err := s.executePipeline(cmd)
lastErr = err
s.setExitCode(err)
}
}
return lastErr
}
func (s *Shell) setExitCode(err error) {
if err != nil {
s.vars["?"] = "1"
} else {
if err == nil {
s.vars["?"] = "0"
}
}
func (s *Shell) executePipeline(input string) error {
input = strings.TrimSpace(input)
if input == "" {
return nil
}
if strings.Contains(input, "|") {
return s.doPipe(input)
}
return s.executeCommand(input)
}
func (s *Shell) executeCommand(input string) error {
parts := s.tokenize(input)
if len(parts) == 0 {
return nil
}
cmdName := parts[0]
args := parts[1:]
if alias, ok := aliases[cmdName]; ok {
fullCmd := alias
if len(args) > 0 {
fullCmd += " " + strings.Join(args, " ")
}
return s.Execute(fullCmd)
}
if builtin, ok := s.builtins[cmdName]; ok {
return builtin(args)
}
return s.executeExternal(cmdName, args)
}
func (s *Shell) tokenize(input string) []string {
var tokens []string
current := ""
inSingle := false
inDouble := false
for i := 0; i < len(input); i++ {
c := input[i]
switch {
case c == '\'' && !inDouble:
inSingle = !inSingle
case c == '"' && !inSingle:
inDouble = !inDouble
case (c == ' ' || c == '\t') && !inSingle && !inDouble:
if current != "" {
tokens = append(tokens, current)
current = ""
}
case c == '=' && !inSingle && !inDouble:
current += string(c)
default:
current += string(c)
}
}
if current != "" {
tokens = append(tokens, current)
}
for i, tok := range tokens {
if strings.HasPrefix(tok, "$") {
key := tok[1:]
if strings.HasPrefix(key, "{") && strings.HasSuffix(key, "}") {
key = key[1 : len(key)-1]
}
if val, ok := s.vars[key]; ok {
tokens[i] = val
} else if val := os.Getenv(key); val != "" {
tokens[i] = val
}
}
}
varAssignIdx := -1
for i, tok := range tokens {
if strings.Contains(tok, "=") && i == 0 {
eqIdx := strings.Index(tok, "=")
if eqIdx > 0 && eqIdx < len(tok)-1 {
name := tok[:eqIdx]
value := tok[eqIdx+1:]
s.vars[name] = value
os.Setenv(name, value)
varAssignIdx = i
}
}
}
if varAssignIdx == 0 && len(tokens) > 1 {
tokens = tokens[1:]
}
return tokens
}
func (s *Shell) executeExternal(cmdName string, args []string) error {
cmdPath := findExecutable(cmdName)
if cmdPath == "" {
return fmt.Errorf("%s: command not found", cmdName)
}
cmd := exec.Command(cmdPath, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func findExecutable(name string) string {
if _, err := os.Stat(name); err == nil {
if info, _ := os.Stat(name); info != nil && !info.IsDir() {
abs, _ := filepath.Abs(name)
return abs
}
}
path := os.Getenv("PATH")
for _, dir := range filepath.SplitList(path) {
fullPath := filepath.Join(dir, name)
info, err := os.Stat(fullPath)
if err == nil && !info.IsDir() {
return fullPath
}
fullPathExe := fullPath + ".exe"
info, err = os.Stat(fullPathExe)
if err == nil && !info.IsDir() {
return fullPathExe
}
}
return ""
}
func (s *Shell) doPipe(input string) error {
commands := strings.Split(input, "|")
type cmdPart struct {
name string
args []string
builtin bool
}
var parts []cmdPart
for _, part := range commands {
part = strings.TrimSpace(part)
if part == "" {
continue
}
tokens := s.tokenize(part)
if len(tokens) == 0 {
continue
}
name := tokens[0]
_, isBuiltin := s.builtins[name]
parts = append(parts, cmdPart{name, tokens[1:], isBuiltin})
}
if len(parts) == 0 {
return nil
}
var prevOutput []byte
for i, p := range parts {
var input []byte
if i > 0 {
input = prevOutput
}
if p.builtin {
output, err := s.captureBuiltin(p.name, p.args, input)
if err != nil {
// Don't return error, let it pass
prevOutput = nil
if i == len(parts)-1 {
return err
}
continue
}
if i == len(parts)-1 {
fmt.Print(string(output))
s.lastExit = 0
} else if ec, ok := err.(exitCodeErr); ok {
s.vars["?"] = fmt.Sprintf("%d", ec.code)
s.lastExit = ec.code
} else {
prevOutput = output
s.vars["?"] = "1"
s.lastExit = 1
}
} else {
cmdPath := findExecutable(p.name)
if cmdPath == "" {
return fmt.Errorf("%s: command not found", p.name)
}
cmd := exec.Command(cmdPath, p.args...)
cmd.Stderr = os.Stderr
if i == 0 && len(input) == 0 {
cmd.Stdin = os.Stdin
} else if len(input) > 0 {
stdin, _ := cmd.StdinPipe()
go func() {
stdin.Write(input)
stdin.Close()
}()
}
if i == len(parts)-1 {
cmd.Stdout = os.Stdout
if err := cmd.Run(); err != nil {
return err
}
} else {
output, err := cmd.Output()
if err != nil {
return err
}
prevOutput = output
}
}
}
return nil
}
func (s *Shell) captureBuiltin(name string, args []string, input []byte) ([]byte, error) {
oldStdout := os.Stdout
oldStdin := os.Stdin
r, w, _ := os.Pipe()
os.Stdout = w
if len(input) > 0 {
ir, iw, _ := os.Pipe()
iw.Write(input)
iw.Close()
os.Stdin = ir
defer ir.Close()
// BuiltinNames returns a sorted list of all registered builtin names (for tab completion).
func (s *Shell) BuiltinNames() []string {
names := make([]string, 0, len(s.builtins)+len(s.funcs))
for k := range s.builtins {
names = append(names, k)
}
fn := s.builtins[name]
err := fn(args)
w.Close()
os.Stdout = oldStdout
os.Stdin = oldStdin
var buf bytes.Buffer
io.Copy(&buf, r)
r.Close()
return buf.Bytes(), err
for k := range s.funcs {
names = append(names, k)
}
return names
}
// withIO temporarily swaps stdin/stdout/stderr, runs fn, then restores.
// Pass nil to leave the corresponding stream unchanged.
func (s *Shell) withIO(stdin io.Reader, stdout io.Writer, stderr io.Writer, fn func() error) error {
oldIn, oldOut, oldErr := s.Stdin, s.Stdout, s.Stderr
if stdin != nil {
s.Stdin = stdin
}
if stdout != nil {
s.Stdout = stdout
}
if stderr != nil {
s.Stderr = stderr
}
err := fn()
s.Stdin, s.Stdout, s.Stderr = oldIn, oldOut, oldErr
return err
}